src

Ψπ (psipy): Symbolic-Numerical Toolkit for PDEs and Hamiltonian Mechanics

Overview

Welcome to psipy, a comprehensive Python ecosystem designed to bridge the gap between formal symbolic mathematics (via SymPy) and high-performance numerical simulation (via NumPy/SciPy). This library provides a unified framework for defining, analyzing, solving, and visualizing complex problems in:

  • Partial Differential Equations (PDEs)
  • Pseudo-Differential Operators (ΨDOs)
  • Hamiltonian and Lagrangian Mechanics
  • Semiclassical and Microlocal Analysis

The core philosophy is to allow users to move seamlessly from a formal symbolic definition—such as a Lagrangian, a Hamiltonian from the included catalog, or a PDE written in SymPy—to a robust numerical analysis, such as solving the PDE's evolution, visualizing its phase-space geometry, or computing its semiclassical spectrum.

Core Components

The psipy ecosystem is composed of several powerful, interoperable modules:

  • PDESolver: The main numerical engine. It parses symbolic PDEs and solves 1D/2D, linear/nonlinear, time-dependent or stationary equations. It uses spectral (FFT) methods with high-order exponential integrators (like ETD-RK4) for robust time evolution.

  • PseudoDifferentialOperator: A complete symbolic and numerical framework for Pseudo-Differential Operators (ΨDOs). It supports symbolic calculus (composition, commutators, adjoints) and microlocal analysis (ellipticity, characteristic sets), bridging formal definitions with numerical evaluation on grids.

  • LagrangianHamiltonianConverter & HamiltonianSymbolicConverter: A symbolic toolkit for analytical mechanics. It performs purely symbolic Legendre transforms (L ↔ H) and can automatically generate formal symbolic PDEs (e.g., Schrödinger, Wave) from any given Hamiltonian symbol.

  • HamiltonianCatalog: A vast, curated, and searchable symbolic database of over 500 Hamiltonian systems. It spans classical mechanics, quantum chaos, biophysics, and more, providing a rich testbed for research and education.

  • SymbolGeometry: A comprehensive analysis and visualization suite for 1D Hamiltonian systems. It connects classical geometry to quantum spectra by computing classical trajectories, periodic orbits, and the semiclassical energy spectrum via the Gutzwiller trace formula and EBK quantization.

  • SymbolGeometry2D: An advanced 2D analysis toolkit for visualizing dynamical systems. It performs rigorous caustic detection by tracking the full 4x4 Jacobian, generates Poincaré sections, and analyzes KAM tori, providing a deep dive into 2D phase space geometry.

Typical Workflow

A common use case involves combining all modules:

  1. Select a System: Fetch a complex Hamiltonian (e.g., "henon_heiles") from the HamiltonianCatalog.

  2. Formulate the PDE: Use SymPhysics to automatically generate the corresponding symbolic Schrödinger equation.

  3. Analyze Geometry: Pass the Hamiltonian symbol to SymbolGeometry2D to visualize its classical trajectories, Poincaré sections, and chaotic regions.

  4. Solve Dynamics: Pass the symbolic PDE to the PDESolver to simulate the quantum wave function's evolution in time.

Example: Solving a Pseudo-Differential PDE

This example defines a 1D Schrödinger-type equation with a non-local, relativistic kinetic term, i ∂ₜ u = √(1 - ∂ₓ²) u.

from solver import *

# 1. Define symbolic variables
t, x, xi = symbols('t x xi', real=True)
u = Function('u')

# 2. Define the PDE symbolically
# The symbol for the operator √(1 - ∂ₓ²) is p(ξ) = √(1 + ξ²)
# (using the Fourier convention p(ξ) → op(ξ) → -∂ₓ²)
p_symbol = (1 + xi**2)**(1/2)

# The equation is: i * ∂ₜ u = psiOp(p_symbol) * u
equation = Eq(I * diff(u(t, x), t), psiOp(p_symbol, u(t, x)))

# 3. Create the solver
solver = PDESolver(equation)

# 4. Setup the simulation domain and initial condition
initial_packet = lambda x: np.exp(-(x - np.pi)**2 / 0.5) * np.exp(1j * 5.0 * x)
solver.setup(
    Lx=2 * np.pi, Nx=256,
    Lt=4.0, Nt=1000,
    initial_condition=initial_packet,
    boundary_condition='periodic'
)

# 5. Solve the PDE
solver.solve()

# 6. Animate the solution
ani = solver.animate()
HTML(ani.to_jshtml())
  1"""
  2Ψπ (psipy): Symbolic-Numerical Toolkit for PDEs and Hamiltonian Mechanics
  3========================================================================
  4
  5## Overview
  6
  7Welcome to `psipy`, a comprehensive Python ecosystem designed to bridge the gap
  8between formal symbolic mathematics (via SymPy) and high-performance numerical
  9simulation (via NumPy/SciPy). This library provides a unified framework for
 10defining, analyzing, solving, and visualizing complex problems in:
 11
 12- Partial Differential Equations (PDEs)
 13- Pseudo-Differential Operators (ΨDOs)
 14- Hamiltonian and Lagrangian Mechanics
 15- Semiclassical and Microlocal Analysis
 16
 17The core philosophy is to allow users to move seamlessly from a formal symbolic
 18definition—such as a Lagrangian, a Hamiltonian from the included catalog, or a
 19PDE written in SymPy—to a robust numerical analysis, such as solving the PDE's
 20evolution, visualizing its phase-space geometry, or computing its semiclassical
 21spectrum.
 22
 23## Core Components
 24
 25The `psipy` ecosystem is composed of several powerful, interoperable modules:
 26
 27- **`PDESolver`**: The main numerical engine. It parses symbolic PDEs and solves
 28  1D/2D, linear/nonlinear, time-dependent or stationary equations. It uses spectral
 29  (FFT) methods with high-order exponential integrators (like ETD-RK4) for robust
 30  time evolution.
 31
 32- **`PseudoDifferentialOperator`**: A complete symbolic and numerical framework for Pseudo-Differential
 33  Operators (ΨDOs). It supports symbolic calculus (composition, commutators, adjoints)
 34  and microlocal analysis (ellipticity, characteristic sets), bridging formal definitions
 35  with numerical evaluation on grids.
 36
 37- **`LagrangianHamiltonianConverter` & `HamiltonianSymbolicConverter`**: A symbolic toolkit for analytical mechanics. It performs purely
 38  symbolic Legendre transforms (L ↔ H) and can automatically generate formal symbolic
 39  PDEs (e.g., Schrödinger, Wave) from any given Hamiltonian symbol.
 40
 41- **`HamiltonianCatalog`**: A vast, curated, and searchable symbolic database of
 42  **over 500** Hamiltonian systems. It spans classical mechanics, quantum chaos,
 43  biophysics, and more, providing a rich testbed for research and education.
 44
 45- **`SymbolGeometry`**: A comprehensive analysis and visualization suite for 1D
 46  Hamiltonian systems. It connects classical geometry to quantum spectra by computing
 47  classical trajectories, periodic orbits, and the semiclassical energy spectrum via
 48  the **Gutzwiller trace formula** and **EBK quantization**.
 49
 50- **`SymbolGeometry2D`**: An advanced 2D analysis toolkit for visualizing dynamical
 51  systems. It performs rigorous **caustic detection** by tracking the full 4x4 Jacobian,
 52  generates **Poincaré sections**, and analyzes **KAM tori**, providing a deep dive
 53  into 2D phase space geometry.
 54
 55## Typical Workflow
 56
 57A common use case involves combining all modules:
 58
 591. **Select a System**: Fetch a complex Hamiltonian (e.g., "henon_heiles")
 60   from the `HamiltonianCatalog`.
 61
 622. **Formulate the PDE**: Use `SymPhysics` to automatically generate the
 63   corresponding symbolic Schrödinger equation.
 64
 653. **Analyze Geometry**: Pass the Hamiltonian symbol to `SymbolGeometry2D`
 66   to visualize its classical trajectories, Poincaré sections, and chaotic regions.
 67
 684. **Solve Dynamics**: Pass the symbolic PDE to the `PDESolver` to
 69   simulate the quantum wave function's evolution in time.
 70
 71## Example: Solving a Pseudo-Differential PDE
 72
 73This example defines a 1D Schrödinger-type equation with a non-local,
 74relativistic kinetic term, i ∂ₜ u = √(1 - ∂ₓ²) u.
 75
 76```python
 77from solver import *
 78
 79# 1. Define symbolic variables
 80t, x, xi = symbols('t x xi', real=True)
 81u = Function('u')
 82
 83# 2. Define the PDE symbolically
 84# The symbol for the operator √(1 - ∂ₓ²) is p(ξ) = √(1 + ξ²)
 85# (using the Fourier convention p(ξ) → op(ξ) → -∂ₓ²)
 86p_symbol = (1 + xi**2)**(1/2)
 87
 88# The equation is: i * ∂ₜ u = psiOp(p_symbol) * u
 89equation = Eq(I * diff(u(t, x), t), psiOp(p_symbol, u(t, x)))
 90
 91# 3. Create the solver
 92solver = PDESolver(equation)
 93
 94# 4. Setup the simulation domain and initial condition
 95initial_packet = lambda x: np.exp(-(x - np.pi)**2 / 0.5) * np.exp(1j * 5.0 * x)
 96solver.setup(
 97    Lx=2 * np.pi, Nx=256,
 98    Lt=4.0, Nt=1000,
 99    initial_condition=initial_packet,
100    boundary_condition='periodic'
101)
102
103# 5. Solve the PDE
104solver.solve()
105
106# 6. Animate the solution
107ani = solver.animate()
108HTML(ani.to_jshtml())
109```
110"""
111from importlib.metadata import version
112
113# Imports publics
114from .psiop import *
115from .solver import *
116from .physics import *
117from .geometry_1d import *
118from .geometry_2d import *
119from .hamiltonian_catalog import *
120from .riemannian_1d import *
121from .riemannian_2d import *
122from .symplectic_1d import *
123from .symplectic_2d import *
124from .microlocal_1d import *
125from .microlocal_2d import *
126
127# Version du package
128__version__ = version("psipy")
129
130# Liste des noms exposés par `from psipy import *`
131__all__ = [
132    "PseudoDifferentialOperator",
133    "PDESolver",
134    "LagrangianHamiltonianConverter",
135    "HamiltonianSymbolicConverter",
136    "SymbolGeometry",
137    "SymbolVisualizer",
138    "SpectralAnalysis",
139    "SymbolGeometry2D",
140    "SymbolVisualizer2D",
141    "Utilities2D",
142    # Riemannian 1D
143    'Metric1D',
144    'geodesic_integrator',
145    'laplace_beltrami',
146    
147    # Riemannian 2D
148    'Metric2D',
149    'geodesic_solver',
150    'exponential_map',
151    
152    # Symplectic 1D
153    'SymplecticForm1D',
154    'hamiltonian_flow',
155    'poisson_bracket',
156    
157    # Symplectic 2D
158    'SymplecticForm2D',
159    'hamiltonian_flow_4d',
160    'poincare_section',
161    
162    # Microlocal 1D
163    'characteristic_variety',
164    'bicharacteristic_flow',
165    'wkb_ansatz',
166    'bohr_sommerfeld_quantization',
167    
168    # Microlocal 2D
169    'characteristic_variety_2d',
170    'bichar_flow_2d',
171    'compute_maslov_index',
172]
class PseudoDifferentialOperator:
  28class PseudoDifferentialOperator:
  29    """
  30    Pseudo-differential operator with dynamic symbol evaluation on spatial grids.
  31    Supports both 1D and 2D operators, and can be defined explicitly (symbol mode)
  32    or extracted automatically from symbolic equations (auto mode).
  33
  34    Parameters
  35    ----------
  36    expr : sympy expression
  37        Symbolic expression representing the pseudo-differential symbol.
  38    vars_x : list of sympy symbols
  39        Spatial variables (e.g., [x] for 1D, [x, y] for 2D).
  40    var_u : sympy function, optional
  41        Function u(x, t) used in auto mode to extract the operator symbol.
  42    mode : str, {'symbol', 'auto'}
  43        - 'symbol': directly uses expr as the operator symbol.
  44        - 'auto': computes the symbol automatically by applying expr to exp(i x ξ).
  45
  46    Attributes
  47    ----------
  48    dim : int
  49        Spatial dimension (1 or 2).
  50    fft, ifft : callable
  51        Fast Fourier transform and inverse (scipy.fft or scipy.fft2).
  52    p_func : callable
  53        Evaluated symbol function ready for numerical use.
  54
  55    Notes
  56    -----
  57    - In 'symbol' mode, `expr` should be expressed in terms of spatial variables and frequency variables (ξ, η).
  58    - In 'auto' mode, the symbol is derived by applying the differential expression to a complex exponential.
  59    - Frequency variables are internally named 'xi' and 'eta' for consistency.
  60    - Uses numpy for numerical evaluation and scipy.fft for FFT operations.
  61
  62    Examples
  63    --------
  64    >>> # Example 1: 1D Laplacian operator (symbol mode)
  65    >>> from sympy import symbols
  66    >>> x, xi = symbols('x xi', real=True)
  67    >>> op = PseudoDifferentialOperator(expr=xi**2, vars_x=[x], mode='symbol')
  68
  69    >>> # Example 2: 1D transport operator (auto mode)
  70    >>> from sympy import Function
  71    >>> u = Function('u')
  72    >>> expr = u(x).diff(x)
  73    >>> op = PseudoDifferentialOperator(expr=expr, vars_x=[x], var_u=u(x), mode='auto')
  74    """
  75
  76    def __init__(self, expr, vars_x, var_u=None, mode='symbol'):
  77        self.dim = len(vars_x)
  78        self.mode = mode
  79        self.symbol_cached = None
  80        self.expr = expr
  81        self.vars_x = vars_x
  82
  83        if self.dim == 1:
  84            x, = vars_x
  85            xi_internal = symbols('xi', real=True)
  86            expr = expr.subs(symbols('xi', real=True), xi_internal)
  87            self.fft = partial(fft, workers=FFT_WORKERS)
  88            self.ifft = partial(ifft, workers=FFT_WORKERS)
  89
  90            if mode == 'symbol':
  91                self.p_func = lambdify((x, xi_internal), expr, 'numpy')
  92                self.symbol = expr
  93            elif mode == 'auto':
  94                if var_u is None:
  95                    raise ValueError("var_u must be provided in mode='auto'")
  96                exp_i = exp(I * x * xi_internal)
  97                P_ei = expr.subs(var_u, exp_i)
  98                symbol = simplify(P_ei / exp_i)
  99                symbol = expand(symbol)
 100                self.symbol = symbol
 101                self.p_func = lambdify((x, xi_internal), symbol, 'numpy')
 102            else:
 103                raise ValueError("mode must be 'auto' or 'symbol'")
 104
 105        elif self.dim == 2:
 106            x, y = vars_x
 107            xi_internal, eta_internal = symbols('xi eta', real=True)
 108            expr = expr.subs(symbols('xi', real=True), xi_internal)
 109            expr = expr.subs(symbols('eta', real=True), eta_internal)
 110            self.fft = partial(fft2, workers=FFT_WORKERS)
 111            self.ifft = partial(ifft2, workers=FFT_WORKERS)
 112
 113            if mode == 'symbol':
 114                self.symbol = expr
 115                self.p_func = lambdify((x, y, xi_internal, eta_internal), expr, 'numpy')
 116            elif mode == 'auto':
 117                if var_u is None:
 118                    raise ValueError("var_u must be provided in mode='auto'")
 119                exp_i = exp(I * (x * xi_internal + y * eta_internal))
 120                P_ei = expr.subs(var_u, exp_i)
 121                symbol = simplify(P_ei / exp_i)
 122                symbol = expand(symbol)
 123                self.symbol = symbol
 124                self.p_func = lambdify((x, y, xi_internal, eta_internal), symbol, 'numpy')
 125            else:
 126                raise ValueError("mode must be 'auto' or 'symbol'")
 127
 128        else:
 129            raise NotImplementedError("Only 1D and 2D supported")
 130
 131        if mode == 'auto':
 132            print("\nsymbol = ")
 133            pprint(self.symbol, num_columns=NUM_COLS)
 134        
 135    def evaluate(self, X, Y, KX, KY, cache=True):
 136        """
 137        Evaluate the pseudo-differential operator's symbol on a grid of spatial and frequency coordinates.
 138
 139        The method dynamically selects between 1D and 2D evaluation based on the spatial dimension.
 140        If caching is enabled and a cached symbol exists, it returns the cached result to avoid recomputation.
 141
 142        Parameters
 143        ----------
 144        X, Y : ndarray
 145            Spatial grid coordinates. In 1D, Y is ignored.
 146        KX, KY : ndarray
 147            Frequency grid coordinates. In 1D, KY is ignored.
 148        cache : bool, default=True
 149            If True, stores the computed symbol for reuse in subsequent calls to avoid redundant computation.
 150
 151        Returns
 152        -------
 153        ndarray
 154            Evaluated symbol values over the input grid. Shape matches the input spatial/frequency grids.
 155
 156        Raises
 157        ------
 158        NotImplementedError
 159            If the spatial dimension is not 1D or 2D.
 160        """
 161        if cache and self.symbol_cached is not None:
 162            return self.symbol_cached
 163
 164        if self.dim == 1:
 165            symbol = self.p_func(X, KX)
 166        elif self.dim == 2:
 167            symbol = self.p_func(X, Y, KX, KY)
 168
 169        if cache:
 170            self.symbol_cached = symbol
 171
 172        return symbol
 173
 174    def clear_cache(self):
 175        """
 176        Clear cached symbol evaluations.
 177        """        
 178        self.symbol_cached = None
 179
 180    def apply(self, u, x_grid, kx, boundary_condition='periodic', 
 181              y_grid=None, ky=None, dealiasing_mask=None,
 182              freq_window='gaussian', clamp=1e6, space_window=False):
 183        """
 184        Apply the pseudo-differential operator to the input field u.
 185    
 186        This method dispatches the application of the pseudo-differential operator based on:
 187        
 188        - Whether the symbol is spatially dependent (x/y)
 189        - The boundary condition in use (periodic or dirichlet)
 190    
 191        Supported operations:
 192        
 193        - Constant-coefficient symbols: applied via Fourier multiplication.
 194        - Spatially varying symbols: applied via Kohn–Nirenberg quantization.
 195        - Dirichlet boundary conditions: handled with non-periodic convolution-like quantization.
 196    
 197        Dispatch Logic:\n
 198        if not self.is_spatial: u ↦ Op(p)(D) ⋅ u = 𝓕⁻¹[ p(ξ) ⋅ 𝓕(u) ]\n
 199        elif periodic: u ↦ Op(p)(x,D) ⋅ u ≈ ∫ eᶦˣᶿ p(x, ξ) 𝓕(u)(ξ) dξ based of FFT (quicker)\n
 200        elif dirichlet: u ↦ Op(p)(x,D) ⋅ u ≈ u ≈ ∫ eᶦˣᶿ p(x, ξ) 𝓕(u)(ξ) dξ (slower)\n
 201        
 202        Parameters
 203        ----------
 204        u : ndarray
 205            Function to which the operator is applied
 206        x_grid : ndarray
 207            Spatial grid in x direction
 208        kx : ndarray
 209            Frequency grid in x direction
 210        boundary_condition : str
 211            'periodic' or 'dirichlet'
 212        y_grid : ndarray, optional
 213            Spatial grid in y direction (for 2D)
 214        ky : ndarray, optional
 215            Frequency grid in y direction (for 2D)
 216        dealiasing_mask : ndarray, optional
 217            Dealiasing mask
 218        freq_window : str
 219            Frequency windowing ('gaussian' or 'hann')
 220        clamp : float
 221            Clamp symbol values to [-clamp, clamp]
 222        space_window : bool
 223            Apply spatial windowing
 224            
 225        Returns
 226        -------
 227        ndarray
 228            Result of applying the operator
 229        """
 230        # Check if symbol depends on spatial variables
 231        is_spatial = self._is_spatial_dependent()
 232        
 233        # Case 1: Constant symbol with periodic BC (fast path)
 234        if not is_spatial and boundary_condition == 'periodic':
 235            return self._apply_constant_fft(u, x_grid, kx, y_grid, ky, dealiasing_mask)
 236        
 237        # Case 2: Spatial symbol with periodic BC
 238        elif boundary_condition == 'periodic':
 239            symbol_func = self._get_symbol_func()
 240            return kohn_nirenberg_fft(
 241                u_vals=u,
 242                symbol_func=symbol_func,
 243                x_grid=x_grid,
 244                kx=kx,
 245                fft_func=self.fft,
 246                ifft_func=self.ifft,
 247                dim=self.dim,
 248                y_grid=y_grid,
 249                ky=ky,
 250                freq_window=freq_window,
 251                clamp=clamp,
 252                space_window=space_window
 253            )
 254        
 255        # Case 3: Dirichlet BC (non-periodic)
 256        elif boundary_condition == 'dirichlet':
 257            symbol_func = self._get_symbol_func()
 258            
 259            if self.dim == 1:
 260                return kohn_nirenberg_nonperiodic(
 261                    u_vals=u,
 262                    x_grid=x_grid,
 263                    xi_grid=kx,
 264                    symbol_func=symbol_func,
 265                    freq_window=freq_window,
 266                    clamp=clamp,
 267                    space_window=space_window
 268                )
 269            elif self.dim == 2:
 270                return kohn_nirenberg_nonperiodic(
 271                    u_vals=u,
 272                    x_grid=(x_grid, y_grid),
 273                    xi_grid=(kx, ky),
 274                    symbol_func=symbol_func,
 275                    freq_window=freq_window,
 276                    clamp=clamp,
 277                    space_window=space_window
 278                )
 279        
 280        else:
 281            raise ValueError(f"Invalid boundary condition '{boundary_condition}'")
 282    
 283    def _is_spatial_dependent(self):
 284        """
 285        Check if the symbol depends on spatial variables.
 286        
 287        Returns
 288        -------
 289        bool
 290            True if symbol depends on x (or x, y)
 291        """
 292        if self.dim == 1:
 293            return self.symbol.has(self.vars_x[0])
 294        elif self.dim == 2:
 295            x, y = self.vars_x
 296            return self.symbol.has(x) or self.symbol.has(y)
 297        else:
 298            return False
 299    
 300    def _get_symbol_func(self):
 301        """
 302        Get a lambdified version of the symbol.
 303        
 304        Returns
 305        -------
 306        callable
 307            Lambdified symbol function
 308        """
 309        if self.dim == 1:
 310            x = self.vars_x[0]
 311            xi = symbols('xi', real=True)
 312            return lambdify((x, xi), self.symbol, 'numpy')
 313        elif self.dim == 2:
 314            x, y = self.vars_x
 315            xi, eta = symbols('xi eta', real=True)
 316            return lambdify((x, y, xi, eta), self.symbol, 'numpy')
 317        else:
 318            raise NotImplementedError("Only 1D and 2D supported")
 319    
 320    def _apply_constant_fft(self, u, x_grid, kx, y_grid, ky, dealiasing_mask):
 321        """
 322        Apply a constant-coefficient pseudo-differential operator in Fourier space.
 323
 324        This method assumes the symbol is diagonal in the Fourier basis and acts as a 
 325        multiplication operator. It performs the operation:
 326        
 327            (ψu)(x) = 𝓕⁻¹[ -σ(k) · 𝓕[u](k) ]
 328
 329        where:
 330        - σ(k) is the combined pseudo-differential operator symbol
 331        - 𝓕 denotes the forward Fourier transform
 332        - 𝓕⁻¹ denotes the inverse Fourier transform
 333
 334        The dealiasing mask is applied before returning to physical space.
 335        
 336        Parameters
 337        ----------
 338        u : ndarray
 339            Input function
 340        x_grid : ndarray
 341            Spatial grid (x)
 342        kx : ndarray
 343            Frequency grid (x)
 344        y_grid : ndarray, optional
 345            Spatial grid (y, for 2D)
 346        ky : ndarray, optional
 347            Frequency grid (y, for 2D)
 348        dealiasing_mask : ndarray, optional
 349            Dealiasing mask
 350            
 351        Returns
 352        -------
 353        ndarray
 354            Result
 355        """
 356        u_hat = self.fft(u)
 357        
 358        # Evaluate symbol at grid points
 359        if self.dim == 1:
 360            X_dummy = np.zeros_like(kx)
 361            symbol_vals = self.p_func(X_dummy, kx)
 362        elif self.dim == 2:
 363            KX, KY = np.meshgrid(kx, ky, indexing='ij')
 364            X_dummy = np.zeros_like(KX)
 365            Y_dummy = np.zeros_like(KY)
 366            symbol_vals = self.p_func(X_dummy, Y_dummy, KX, KY)
 367        else:
 368            raise ValueError("Only 1D and 2D supported")
 369        
 370        # Apply symbol
 371        u_hat *= symbol_vals
 372        
 373        # Apply dealiasing
 374        if dealiasing_mask is not None:
 375            u_hat *= dealiasing_mask
 376        
 377        return self.ifft(u_hat)
 378
 379    def principal_symbol(self, order=1):
 380        """
 381        Compute the leading homogeneous component of the pseudo-differential symbol.
 382
 383        This method extracts the principal part of the symbol, which is the dominant 
 384        term under high-frequency asymptotics (|ξ| → ∞). The expansion is performed 
 385        in polar coordinates for 2D symbols to maintain rotational symmetry, then 
 386        converted back to Cartesian form.
 387
 388        Parameters
 389        ----------
 390        order : int
 391            Order of the asymptotic expansion in powers of 1/ρ, where ρ = |ξ| in 1D 
 392            or ρ = sqrt(ξ² + η²) in 2D. Only the leading-order term is returned.
 393
 394        Returns
 395        -------
 396        sympy.Expr
 397            The principal symbol component, homogeneous of degree `m - order`, where 
 398            `m` is the original symbol's order.
 399
 400        Notes:
 401        - In 1D, uses direct series expansion in ξ.
 402        - In 2D, expands in radial variable ρ while preserving angular dependence.
 403        - Useful for microlocal analysis and constructing parametrices.
 404        """
 405
 406        p = self.symbol
 407        if self.dim == 1:
 408            xi = symbols('xi', real=True, positive=True)
 409            return simplify(series(p, xi, oo, n=order).removeO())
 410        elif self.dim == 2:
 411            xi, eta = symbols('xi eta', real=True, positive=True)
 412            # Homogeneous radial expansion: we set (ξ, η) = ρ (cosθ, sinθ)
 413            rho, theta = symbols('rho theta', real=True, positive=True)
 414            p_rho = p.subs({xi: rho * cos(theta), eta: rho * sin(theta)})
 415            expansion = series(p_rho, rho, oo, n=order).removeO()
 416            # Revert back to (ξ, η)
 417            expansion_cart = expansion.subs({rho: sqrt(xi**2 + eta**2),
 418                                             cos(theta): xi / sqrt(xi**2 + eta**2),
 419                                             sin(theta): eta / sqrt(xi**2 + eta**2)})
 420            return simplify(powdenest(expansion_cart, force=True))
 421                       
 422    def is_homogeneous(self, tol=1e-10):
 423        """
 424        Check whether the symbol is homogeneous in the frequency variables.
 425    
 426        Returns
 427        -------
 428        (bool, Rational or float or None)
 429            Tuple (is_homogeneous, degree) where:
 430            - is_homogeneous: True if the symbol satisfies p(λξ, λη) = λ^m * p(ξ, η)
 431            - degree: the detected degree m if homogeneous, or None
 432        """
 433        from sympy import symbols, simplify, expand, Eq
 434        from sympy.abc import l
 435    
 436        if self.dim == 1:
 437            xi = symbols('xi', real=True, positive=True)
 438            l = symbols('l', real=True, positive=True)
 439            p = self.symbol
 440            p_scaled = p.subs(xi, l * xi)
 441            ratio = simplify(p_scaled / p)
 442            if ratio.has(xi):
 443                return False, None
 444            try:
 445                deg = simplify(ratio).as_base_exp()[1]
 446                return True, deg
 447            except Exception:
 448                return False, None
 449    
 450        elif self.dim == 2:
 451            xi, eta = symbols('xi eta', real=True, positive=True)
 452            l = symbols('l', real=True, positive=True)
 453            p = self.symbol
 454            p_scaled = p.subs({xi: l * xi, eta: l * eta})
 455            ratio = simplify(p_scaled / p)
 456            # If ratio == l**m with no (xi, eta) left, it's homogeneous
 457            if ratio.has(xi, eta):
 458                return False, None
 459            try:
 460                base, exp = ratio.as_base_exp()
 461                if base == l:
 462                    return True, exp
 463            except Exception:
 464                pass
 465            return False, None
 466
 467    def symbol_order(self, max_order=10, tol=1e-3):
 468        """
 469        Estimate the homogeneity order of the pseudo-differential symbol in high-frequency asymptotics.
 470    
 471        This method attempts to determine the leading-order behavior of the symbol p(x, ξ) or p(x, y, ξ, η)
 472        as |ξ| → ∞ (in 1D) or |(ξ, η)| → ∞ (in 2D). The returned value represents the asymptotic growth or decay rate,
 473        which is essential for understanding the regularity and mapping properties of the corresponding operator.
 474    
 475        The function uses symbolic preprocessing to ensure proper factorization of frequency variables,
 476        especially in sqrt and power expressions, to avoid erroneous order detection (e.g., due to hidden scaling).
 477    
 478        Parameters
 479        ----------
 480        max_order : int, optional
 481            Maximum number of terms to consider in the series expansion. Default is 10.
 482        tol : float, optional
 483            Tolerance threshold for evaluating the coefficient magnitude. If the coefficient is too small,
 484            the detected order may be discarded. Default is 1e-3.
 485    
 486        Returns
 487        -------
 488        float or None
 489            - If the symbol is homogeneous, returns its exact homogeneity degree as a float.
 490            - Otherwise, estimates the dominant asymptotic order from leading terms in the expansion.
 491            - Returns None if no valid order could be determined.
 492    
 493        Notes
 494        -----
 495        - In 1D:
 496            Two strategies are used:
 497                1. Expand directly in xi at infinity.
 498                2. Substitute xi = 1/z and expand around z = 0.
 499    
 500        - In 2D:
 501            - Transform the symbol into polar coordinates: (xi, eta) = rho*(cos(theta), sin(theta)).
 502            - Expand in rho at infinity, then extract the leading term's power.
 503            - An alternative substitution using 1/z is also tried if the first method fails.
 504    
 505        - Preprocessing steps:
 506            - Sqrt expressions involving frequencies are rewritten to isolate the leading variable.
 507            - Power expressions are factored explicitly to ensure correct symbolic scaling.
 508    
 509        - If the symbol is not homogeneous, a warning is issued, and the result should be interpreted with care.
 510        
 511        - For non-homogeneous symbols, only the principal asymptotic term is considered.
 512    
 513        Raises
 514        ------
 515        NotImplementedError
 516            If the spatial dimension is neither 1 nor 2.
 517        """
 518        from sympy import (
 519            symbols, series, simplify, sqrt, cos, sin, oo, powdenest, radsimp,
 520            expand, expand_power_base
 521        )
 522    
 523        def preprocess_sqrt(expr, freq):
 524            return expr.replace(
 525                lambda e: e.func == sqrt and freq in e.free_symbols,
 526                lambda e: freq * sqrt(1 + (e.args[0] - freq**2) / freq**2)
 527            )
 528    
 529        def preprocess_power(expr, freq):
 530            return expr.replace(
 531                lambda e: e.is_Pow and freq in e.free_symbols,
 532                lambda e: freq**e.exp * (1 + e.base / freq**e.base.as_powers_dict().get(freq, 0))**e.exp
 533            )
 534    
 535        def validate_order(power, coeff, vars_x, tol):
 536            if power is None:
 537                return None
 538            if any(v in coeff.free_symbols for v in vars_x):
 539                print("⚠️ Coefficient depends on spatial variables; ignoring")
 540                return None
 541            try:
 542                coeff_val = abs(float(coeff.evalf()))
 543                if coeff_val < tol:
 544                    print(f"⚠️ Coefficient too small ({coeff_val:.2e} < {tol})")
 545                    return None
 546            except Exception as e:
 547                print(f"⚠️ Coefficient evaluation failed: {e}")
 548                return None
 549            return int(power) if power == int(power) else float(power)
 550    
 551        # Homogeneity check
 552        is_homog, degree = self.is_homogeneous()
 553        if is_homog:
 554            return float(degree)
 555        else:
 556            print("⚠️ The symbol is not homogeneous. The asymptotic order is not well defined.")
 557    
 558        if self.dim == 1:
 559            x = self.vars_x[0]
 560            xi = symbols('xi', real=True, positive=True)
 561    
 562            try:
 563                print("1D symbol_order - method 1")
 564                expr = preprocess_sqrt(self.symbol, xi)
 565                s = series(expr, xi, oo, n=max_order).removeO()
 566                lead = simplify(powdenest(s.as_leading_term(xi), force=True))
 567                power = lead.as_powers_dict().get(xi, None)
 568                coeff = lead / xi**power if power is not None else 0
 569                print("lead =", lead)
 570                print("power =", power)
 571                print("coeff =", coeff)
 572                order = validate_order(power, coeff, [x], tol)
 573                if order is not None:
 574                    return order
 575            except Exception:
 576                pass
 577    
 578            try:
 579                print("1D symbol_order - method 2")
 580                z = symbols('z', real=True, positive=True)
 581                expr_z = preprocess_sqrt(self.symbol.subs(xi, 1/z), 1/z)
 582                s = series(expr_z, z, 0, n=max_order).removeO()
 583                lead = simplify(powdenest(s.as_leading_term(z), force=True))
 584                power = lead.as_powers_dict().get(z, None)
 585                coeff = lead / z**power if power is not None else 0
 586                print("lead =", lead)
 587                print("power =", power)
 588                print("coeff =", coeff)
 589                order = validate_order(power, coeff, [x], tol)
 590                if order is not None:
 591                    return -order
 592            except Exception as e:
 593                print(f"⚠️ fallback z failed: {e}")
 594            return None
 595    
 596        elif self.dim == 2:
 597            x, y = self.vars_x
 598            xi, eta = symbols('xi eta', real=True, positive=True)
 599            rho, theta = symbols('rho theta', real=True, positive=True)
 600    
 601            try:
 602                print("2D symbol_order - method 1")
 603                p_rho = self.symbol.subs({xi: rho * cos(theta), eta: rho * sin(theta)})
 604                p_rho = preprocess_power(preprocess_sqrt(p_rho, rho), rho)
 605                s = series(simplify(p_rho), rho, oo, n=max_order).removeO()
 606                lead = radsimp(simplify(powdenest(s.as_leading_term(rho), force=True)))
 607                power = lead.as_powers_dict().get(rho, None)
 608                coeff = lead / rho**power if power is not None else 0
 609                print("lead =", lead)
 610                print("power =", power)
 611                print("coeff =", coeff)
 612                order = validate_order(power, coeff, [x, y], tol)
 613                if order is not None:
 614                    return order
 615            except Exception as e:
 616                print(f"⚠️ polar expansion failed: {e}")
 617    
 618            try:
 619                print("2D symbol_order - method 2")
 620                z = symbols('z', real=True, positive=True)
 621                xi_eta = {xi: (1/z) * cos(theta), eta: (1/z) * sin(theta)}
 622                p_rho = preprocess_sqrt(self.symbol.subs(xi_eta), 1/z)
 623                s = series(simplify(p_rho), z, 0, n=max_order).removeO()
 624                lead = radsimp(simplify(powdenest(s.as_leading_term(z), force=True)))
 625                power = lead.as_powers_dict().get(z, None)
 626                coeff = lead / z**power if power is not None else 0
 627                print("lead =", lead)
 628                print("power =", power)
 629                print("coeff =", coeff)
 630                order = validate_order(power, coeff, [x, y], tol)
 631                if order is not None:
 632                    return -order
 633            except Exception as e:
 634                print(f"⚠️ fallback z (2D) failed: {e}")
 635            return None
 636    
 637        else:
 638            raise NotImplementedError("Only 1D and 2D supported.")
 639
 640    
 641    def asymptotic_expansion(self, order=3):
 642        """
 643        Compute the asymptotic expansion of the symbol as |ξ| → ∞ (high-frequency regime).
 644    
 645        This method expands the pseudo-differential symbol in inverse powers of the 
 646        frequency variable(s), either in 1D or 2D. It handles both polynomial and 
 647        exponential symbols by performing a series expansion in 1/|ξ| up to the specified order.
 648    
 649        The expansion is performed directly in Cartesian coordinates for 1D symbols.
 650        For 2D symbols, the method uses polar coordinates (ρ, θ) to perform the expansion 
 651        at infinity in ρ, then converts the result back to Cartesian coordinates.
 652    
 653        Parameters
 654        ----------
 655        order : int, optional
 656            Maximum order of the asymptotic expansion. Default is 3.
 657    
 658        Returns
 659        -------
 660        sympy.Expr
 661            The asymptotic expansion of the symbol up to the given order, expressed in Cartesian coordinates.
 662            If expansion fails, returns the original unexpanded symbol.
 663    
 664        Notes:
 665        - In 1D: expansion is performed directly in terms of ξ.
 666        - In 2D: the symbol is first rewritten in polar coordinates (ρ,θ), expanded asymptotically 
 667          in ρ → ∞, then converted back to Cartesian coordinates (ξ,η).
 668        - Handles special case when the symbol is an exponential function by expanding its argument.
 669        - Symbolic normalization is applied early (via `simplify`) for 2D expressions to improve convergence.
 670        - Robust to failures: catches exceptions and issues warnings instead of raising errors.
 671        - Final expression is simplified using `powdenest` and `expand` for improved readability.
 672        """
 673        p = self.symbol
 674    
 675        if self.dim == 1:
 676            xi = symbols('xi', real=True, positive=True)
 677    
 678            try:
 679                # Case: exponential function
 680                if p.func == exp and len(p.args) == 1:
 681                    arg = p.args[0]
 682                    arg_series = series(arg, xi, oo, n=order).removeO()
 683                    expanded = series(exp(expand(arg_series)), xi, oo, n=order).removeO()
 684                    return simplify(powdenest(expanded, force=True))
 685                else:
 686                    expanded = series(p, xi, oo, n=order).removeO()
 687                    return simplify(powdenest(expanded, force=True))
 688    
 689            except Exception as e:
 690                print(f"Warning: 1D expansion failed: {e}")
 691                return p
 692    
 693        elif self.dim == 2:
 694            xi, eta = symbols('xi eta', real=True, positive=True)
 695            rho, theta = symbols('rho theta', real=True, positive=True)
 696    
 697            # Normalize before substitution
 698            p = simplify(p)
 699    
 700            # Substitute polar coordinates
 701            p_polar = p.subs({
 702                xi: rho * cos(theta),
 703                eta: rho * sin(theta)
 704            })
 705    
 706            try:
 707                # Handle exponentials
 708                if p_polar.func == exp and len(p_polar.args) == 1:
 709                    arg = p_polar.args[0]
 710                    arg_series = series(arg, rho, oo, n=order).removeO()
 711                    expanded = series(exp(expand(arg_series)), rho, oo, n=order).removeO()
 712                else:
 713                    expanded = series(p_polar, rho, oo, n=order).removeO()
 714    
 715                # Convert back to Cartesian
 716                norm = sqrt(xi**2 + eta**2)
 717                expansion_cart = expanded.subs({
 718                    rho: norm,
 719                    cos(theta): xi / norm,
 720                    sin(theta): eta / norm
 721                })
 722    
 723                # Final simplifications
 724                result = simplify(powdenest(expansion_cart, force=True))
 725                result = expand(result)
 726                return result
 727    
 728            except Exception as e:
 729                print(f"Warning: 2D expansion failed: {e}")
 730                return p  
 731            
 732    def compose_asymptotic(self, other, order=1, mode='kn', sign_convention=None):
 733        """
 734        Compose two pseudo-differential operators using an asymptotic expansion
 735        in the chosen quantization scheme (Kohn–Nirenberg or Weyl).
 736    
 737        Parameters
 738        ----------
 739        other : PseudoDifferentialOperator
 740            The operator to compose with this one.
 741        order : int, default=1
 742            Maximum order of the asymptotic expansion.
 743        mode : {'kn', 'weyl'}, default='kn'
 744            Quantization mode:
 745            - 'kn' : Kohn–Nirenberg quantization (left-quantized)
 746            - 'weyl' : Weyl symmetric quantization
 747        sign_convention : {'standard', 'inverse'}, optional
 748            Controls the phase factor convention for the KN case:
 749            - 'standard' → (i)^(-n), gives [x, ξ] = +i (physics convention)
 750            - 'inverse' → (i)^(+n), gives [x, ξ] = -i (mathematical adjoint convention)
 751            If None, defaults to 'standard'.
 752    
 753        Returns
 754        -------
 755        sympy.Expr
 756            Symbolic expression for the composed symbol up to the given order.
 757    
 758        Notes
 759        -----
 760        - In 1D (Kohn–Nirenberg):
 761            (p ∘ q)(x, ξ) ~ Σₙ (1/n!) (i sgn)^n ∂_ξⁿ p(x, ξ) ∂_xⁿ q(x, ξ)
 762        - In 1D (Weyl):
 763            (p # q)(x, ξ) = exp[(i/2)(∂_ξ^p ∂_x^q - ∂_x^p ∂_ξ^q)] p(x, ξ) q(x, ξ)
 764            truncated at given order.
 765    
 766        Examples
 767        --------
 768        X = a*x, Y = b*ξ
 769        X_op.compose_asymptotic(Y_op, order=3, mode='weyl')
 770        """
 771    
 772        from sympy import diff, factorial, simplify, symbols
 773    
 774        assert self.dim == other.dim, "Operator dimensions must match"
 775        p, q = self.symbol, other.symbol
 776    
 777        # Default sign convention
 778        if sign_convention is None:
 779            sign_convention = 'standard'
 780        sign = -1 if sign_convention == 'standard' else +1
 781    
 782        # --- 1D case ---
 783        if self.dim == 1:
 784            x = self.vars_x[0]
 785            xi = symbols('xi', real=True)
 786            result = 0
 787    
 788            if mode == 'kn':  # Kohn–Nirenberg
 789                for n in range(order + 1):
 790                    term = (1 / factorial(n)) * diff(p, xi, n) * diff(q, x, n) * (1j) ** (sign * n)
 791                    result += term
 792    
 793            elif mode == 'weyl':  # Weyl symmetric composition
 794                # Weyl star product: exp((i/2)(∂_ξ^p ∂_x^q - ∂_x^p ∂_ξ^q))
 795                result = 0
 796                for n in range(order + 1):
 797                    for k in range(n + 1):
 798                        # k derivatives acting as (∂_ξ^k p)(∂_x^(n−k) q)
 799                        coeff = (1 / (factorial(k) * factorial(n - k))) * ((1j / 2) ** n) * ((-1) ** (n - k))
 800                        term = coeff * diff(p, xi, k, x, n - k, evaluate=True) * diff(q, x, k, xi, n - k, evaluate=True)
 801                        result += term
 802    
 803            else:
 804                raise ValueError("mode must be either 'kn' or 'weyl'")
 805    
 806            return simplify(result)
 807    
 808        # --- 2D case ---
 809        elif self.dim == 2:
 810            x, y = self.vars_x
 811            xi, eta = symbols('xi eta', real=True)
 812            result = 0
 813    
 814            if mode == 'kn':
 815                for n in range(order + 1):
 816                    for i in range(n + 1):
 817                        j = n - i
 818                        term = (1 / (factorial(i) * factorial(j))) * \
 819                               diff(p, xi, i, eta, j) * diff(q, x, i, y, j) * (1j) ** (sign * n)
 820                        result += term
 821    
 822            elif mode == 'weyl':
 823                for n in range(order + 1):
 824                    for i in range(n + 1):
 825                        j = n - i
 826                        coeff = (1 / (factorial(i) * factorial(j))) * ((1j / 2) ** n) * ((-1) ** (n - i))
 827                        term = coeff * diff(p, xi, i, eta, j, x, 0, y, 0) * diff(q, x, i, y, j, xi, 0, eta, 0)
 828                        result += term
 829            else:
 830                raise ValueError("mode must be either 'kn' or 'weyl'")
 831    
 832            return simplify(result)
 833    
 834        else:
 835            raise NotImplementedError("Only 1D and 2D cases are implemented")
 836
 837    def commutator_symbolic(self, other, order=1, mode='kn', sign_convention=None):
 838        """
 839        Compute the symbolic commutator [A, B] = A∘B − B∘A of two pseudo-differential operators
 840        using formal asymptotic expansion of their composition symbols.
 841    
 842        This method computes the asymptotic expansion of the commutator's symbol up to a given 
 843        order, based on the symbolic calculus of pseudo-differential operators in the 
 844        Kohn–Nirenberg quantization. The result is a purely symbolic sympy expression that 
 845        captures the leading-order noncommutativity of the operators.
 846    
 847        Parameters
 848        ----------
 849        other : PseudoDifferentialOperator
 850            The pseudo-differential operator B to commute with this operator A.
 851        order : int, default=1
 852            Maximum order of the asymptotic expansion. 
 853            - order=1 yields the leading term proportional to the Poisson bracket {p, q}.
 854            - Higher orders include correction terms involving higher mixed derivatives.
 855    
 856        Returns
 857        -------
 858        sympy.Expr
 859            Symbolic expression for the asymptotic expansion of the commutator symbol 
 860            σ([A,B]) = σ(A∘B − B∘A).
 861    
 862        """
 863        assert self.dim == other.dim, "Operator dimensions must match"
 864        p, q = self.symbol, other.symbol
 865    
 866        pq = self.compose_asymptotic(other, order=order, mode=mode, sign_convention=sign_convention)
 867        qp = other.compose_asymptotic(self, order=order, mode=mode, sign_convention=sign_convention)
 868        
 869        comm_symbol = simplify(pq-qp)
 870
 871        return comm_symbol
 872
 873    def right_inverse_asymptotic(self, order=1):
 874        """
 875        Construct a formal right inverse R of the pseudo-differential operator P such that 
 876        the composition P ∘ R equals the identity plus a smoothing operator of order -order.
 877    
 878        This method computes an asymptotic expansion for the right inverse using recursive 
 879        corrections based on derivatives of the symbol p(x, ξ) and lower-order terms of R.
 880    
 881        Parameters
 882        ----------
 883        order : int
 884            Number of terms to include in the asymptotic expansion. Higher values improve 
 885            approximation at the cost of complexity and computational effort.
 886    
 887        Returns
 888        -------
 889        sympy.Expr
 890            The symbolic expression representing the formal right inverse R(x, ξ), which satisfies:
 891            P ∘ R = Id + O(⟨ξ⟩^{-order}), where ⟨ξ⟩ = (1 + |ξ|²)^{1/2}.
 892    
 893        Notes
 894        -----
 895        - In 1D: The recursion involves spatial derivatives of R and derivatives of p with respect to ξ.
 896        - In 2D: The multi-index generalization is used with mixed derivatives in ξ and η.
 897        - The construction relies on the non-vanishing of the principal symbol p to ensure invertibility.
 898        - Each term in the expansion corresponds to higher-order corrections involving commutators 
 899          between the operator P and the current approximation of R.
 900        """
 901        p = self.symbol
 902        if self.dim == 1:
 903            x = self.vars_x[0]
 904            xi = symbols('xi', real=True)
 905            r = 1 / p.subs(xi, xi)  # r0
 906            R = r
 907            for n in range(1, order + 1):
 908                term = 0
 909                for k in range(1, n + 1):
 910                    coeff = (1j)**(-k) / factorial(k)
 911                    inner = diff(p, xi, k) * diff(R, x, k)
 912                    term += coeff * inner
 913                R = R - r * term
 914        elif self.dim == 2:
 915            x, y = self.vars_x
 916            xi, eta = symbols('xi eta', real=True)
 917            r = 1 / p.subs({xi: xi, eta: eta})
 918            R = r
 919            for n in range(1, order + 1):
 920                term = 0
 921                for k1 in range(n + 1):
 922                    for k2 in range(n + 1 - k1):
 923                        if k1 + k2 == 0: continue
 924                        coeff = (1j)**(-(k1 + k2)) / (factorial(k1) * factorial(k2))
 925                        dp = diff(p, xi, k1, eta, k2)
 926                        dR = diff(R, x, k1, y, k2)
 927                        term += coeff * dp * dR
 928                R = R - r * term
 929        return R
 930
 931    def left_inverse_asymptotic(self, order=1):
 932        """
 933        Construct a formal left inverse L such that the composition L ∘ P equals the identity 
 934        operator up to terms of order ξ^{-order}. This expansion is performed asymptotically 
 935        at infinity in the frequency variable(s).
 936    
 937        The left inverse is built iteratively using symbolic differentiation and the 
 938        method of asymptotic expansions for pseudo-differential operators. It ensures that:
 939        
 940            L(P(x,ξ),x,D) ∘ P(x,D) = Id + smoothing operator of order -order
 941    
 942        Parameters
 943        ----------
 944        order : int, optional
 945            Maximum number of terms in the asymptotic expansion (default is 1). Higher values 
 946            yield more accurate inverses at the cost of increased computational complexity.
 947    
 948        Returns
 949        -------
 950        sympy.Expr
 951            Symbolic expression representing the principal symbol of the formal left inverse 
 952            operator L(x,ξ). This expression depends on spatial variables and frequencies, 
 953            and includes correction terms up to the specified order.
 954    
 955        Notes
 956        -----
 957        - In 1D: Uses recursive application of the Leibniz formula for symbols.
 958        - In 2D: Generalizes to multi-indices for mixed derivatives in (x,y) and (ξ,η).
 959        - Each term involves combinations of derivatives of the original symbol p(x,ξ) and 
 960          previously computed terms of the inverse.
 961        - Coefficients include powers of 1j (i) and factorial normalization for derivative terms.
 962        """
 963        p = self.symbol
 964        if self.dim == 1:
 965            x = self.vars_x[0]
 966            xi = symbols('xi', real=True)
 967            l = 1 / p.subs(xi, xi)
 968            L = l
 969            for n in range(1, order + 1):
 970                term = 0
 971                for k in range(1, n + 1):
 972                    coeff = (1j)**(-k) / factorial(k)
 973                    inner = diff(L, xi, k) * diff(p, x, k)
 974                    term += coeff * inner
 975                L = L - term * l
 976        elif self.dim == 2:
 977            x, y = self.vars_x
 978            xi, eta = symbols('xi eta', real=True)
 979            l = 1 / p.subs({xi: xi, eta: eta})
 980            L = l
 981            for n in range(1, order + 1):
 982                term = 0
 983                for k1 in range(n + 1):
 984                    for k2 in range(n + 1 - k1):
 985                        if k1 + k2 == 0: continue
 986                        coeff = (1j)**(-(k1 + k2)) / (factorial(k1) * factorial(k2))
 987                        dp = diff(p, x, k1, y, k2)
 988                        dL = diff(L, xi, k1, eta, k2)
 989                        term += coeff * dL * dp
 990                L = L - term * l
 991        return L
 992
 993    def formal_adjoint(self):
 994        """
 995        Compute the formal adjoint symbol P* of the pseudo-differential operator.
 996
 997        The adjoint is defined such that for any test functions u and v,
 998        ⟨P u, v⟩ = ⟨u, P* v⟩ holds in the distributional sense. This is obtained by 
 999        taking the complex conjugate of the symbol and expanding it asymptotically 
1000        at infinity to ensure proper behavior under integration by parts.
1001
1002        Returns
1003        -------
1004        sympy.Expr
1005            The adjoint symbol P*(x, ξ) in 1D or P*(x, y, ξ, η) in 2D.
1006        
1007        Notes:
1008        - In 1D, the expansion is performed in powers of 1/|ξ|.
1009        - In 2D, the expansion is radial in |ξ| = sqrt(ξ² + η²).
1010        - This method ensures symbolic simplifications for readability and efficiency.
1011        """
1012        p = self.symbol
1013        if self.dim == 1:
1014            x, = self.vars_x
1015            xi = symbols('xi', real=True)
1016            p_star = conjugate(p)
1017            p_star = simplify(series(p_star, xi, oo, n=6).removeO())
1018            return p_star
1019        elif self.dim == 2:
1020            x, y = self.vars_x
1021            xi, eta = symbols('xi eta', real=True)
1022            p_star = conjugate(p)
1023            p_star = simplify(series(p_star, sqrt(xi**2 + eta**2), oo, n=6).removeO())
1024            return p_star
1025
1026    def exponential_symbol(self, t=1.0, order=1, mode='kn', sign_convention=None):
1027        """
1028        Compute the symbol of exp(tP) using asymptotic expansion methods.
1029        
1030        This method calculates the exponential of a pseudo-differential operator 
1031        using either a direct power series expansion or a Magnus expansion, 
1032        depending on the structure of the symbol. The result is valid up to 
1033        the specified asymptotic order.
1034        
1035        Parameters
1036        ----------
1037        t : float or sympy.Symbol, default=1.0
1038            Time or evolution parameter. Common uses:
1039            - t = -i*τ for Schrödinger evolution: exp(-iτH)
1040            - t = τ for heat/diffusion: exp(τΔ)
1041            - t for general propagators
1042        order : int, default=3
1043            Maximum order of the asymptotic expansion. Higher orders include 
1044            more composition terms, improving accuracy for small t or when 
1045            non-commutativity effects are significant.
1046        
1047        Returns
1048        -------
1049        sympy.Expr
1050            Symbolic expression for the exponential operator symbol, computed 
1051            as an asymptotic series up to the specified order.
1052        
1053        Notes
1054        -----
1055        - For commutative symbols (e.g., pure multiplication operators), the 
1056          exponential is exact: exp(tP) = exp(t*p(x,ξ)).
1057        
1058        - For general non-commutative operators, the method uses the BCH-type 
1059          expansion via iterated composition:
1060          exp(tP) ~ I + tP + (t²/2!)P∘P + (t³/3!)P∘P∘P + ...
1061          
1062        - Each power P^n is computed via compose_asymptotic, which accounts 
1063          for the non-commutativity through derivative terms.
1064        
1065        - The expansion is valid for |t| small enough or when the symbol has 
1066          appropriate decay/growth properties.
1067        
1068        - In quantum mechanics (Schrödinger): U(t) = exp(-itH/ℏ) represents 
1069          the time evolution operator.
1070        
1071        - In parabolic PDEs (heat equation): exp(tΔ) is the heat kernel.
1072
1073        """
1074        if self.dim == 1:
1075            x = self.vars_x[0]
1076            xi = symbols('xi', real=True)
1077            
1078            # Initialize with identity
1079            result = 1
1080            
1081            # First order term: tP
1082            current_power = self.symbol
1083            result += t * current_power
1084            
1085            # Higher order terms: (t^n/n!) P^n computed via composition
1086            for n in range(2, order + 1):
1087                # Compute P^n = P^(n-1) ∘ P via asymptotic composition
1088                # We use a temporary operator for composition
1089                temp_op = PseudoDifferentialOperator(
1090                    current_power, [x], mode='symbol'
1091                )
1092                current_power = temp_op.compose_asymptotic(self, order=order, mode=mode, sign_convention=sign_convention)
1093                
1094                # Add term (t^n/n!) * P^n
1095                coeff = t**n / factorial(n)
1096                result += coeff * current_power
1097            
1098            return simplify(result)
1099        
1100        elif self.dim == 2:
1101            x, y = self.vars_x
1102            xi, eta = symbols('xi eta', real=True)
1103            
1104            # Initialize with identity
1105            result = 1
1106            
1107            # First order term: tP
1108            current_power = self.symbol
1109            result += t * current_power
1110            
1111            # Higher order terms: (t^n/n!) P^n computed via composition
1112            for n in range(2, order + 1):
1113                # Compute P^n = P^(n-1) ∘ P via asymptotic composition
1114                temp_op = PseudoDifferentialOperator(
1115                    current_power, [x, y], mode='symbol'
1116                )
1117                current_power = temp_op.compose_asymptotic(self, order=order, mode=mode, sign_convention=sign_convention)
1118                
1119                # Add term (t^n/n!) * P^n
1120                coeff = t**n / factorial(n)
1121                result += coeff * current_power
1122            
1123            return simplify(result)
1124        
1125        else:
1126            raise NotImplementedError("Only 1D and 2D operators are supported")
1127        
1128    def trace_formula(self, volume_element=None, numerical=False, 
1129                      x_bounds=None, xi_bounds=None):
1130        """
1131        Compute the semiclassical trace of the pseudo-differential operator.
1132        
1133        The trace formula relates the quantum trace of an operator to a 
1134        phase-space integral of its symbol, providing a fundamental link 
1135        between classical and quantum mechanics. This implementation supports 
1136        both symbolic and numerical integration.
1137        
1138        Parameters
1139        ----------
1140        volume_element : sympy.Expr, optional
1141            Custom volume element for the phase space integration. If None, 
1142            uses the standard Liouville measure dx dξ/(2π)^d.
1143        numerical : bool, default=False
1144            If True, perform numerical integration over specified bounds.
1145            If False, attempt symbolic integration (may fail for complex symbols).
1146        x_bounds : tuple of tuples, optional
1147            Spatial integration bounds. For 1D: ((x_min, x_max),)
1148            For 2D: ((x_min, x_max), (y_min, y_max))
1149            Required if numerical=True.
1150        xi_bounds : tuple of tuples, optional
1151            Frequency integration bounds. For 1D: ((xi_min, xi_max),)
1152            For 2D: ((xi_min, xi_max), (eta_min, eta_max))
1153            Required if numerical=True.
1154        
1155        Returns
1156        -------
1157        sympy.Expr or float
1158            The trace of the operator. Returns a symbolic expression if 
1159            numerical=False, or a float if numerical=True.
1160        
1161        Notes
1162        -----
1163        - The semiclassical trace formula states:
1164          Tr(P) = (2π)^{-d} ∫∫ p(x,ξ) dx dξ
1165          where d is the spatial dimension and p(x,ξ) is the operator symbol.
1166        
1167        - For 1D: Tr(P) = (1/2π) ∫_{-∞}^{∞} ∫_{-∞}^{∞} p(x,ξ) dx dξ
1168        
1169        - For 2D: Tr(P) = (1/4π²) ∫∫∫∫ p(x,y,ξ,η) dx dy dξ dη
1170        
1171        - This formula is exact for trace-class operators and provides an 
1172          asymptotic approximation for general pseudo-differential operators.
1173        
1174        - Physical interpretation: the trace counts the "number of states" 
1175          weighted by the observable p(x,ξ).
1176        
1177        - For projection operators (χ_Ω with χ² = χ), the trace gives the 
1178          dimension of the range, related to the phase space volume of Ω.
1179        
1180        - The factor (2π)^{-d} comes from the quantum normalization of 
1181          coherent states / Weyl quantization.
1182        """
1183        from sympy import integrate, simplify, lambdify
1184        from scipy.integrate import dblquad, nquad
1185        
1186        p = self.symbol
1187        
1188        if numerical:
1189            if x_bounds is None or xi_bounds is None:
1190                raise ValueError(
1191                    "x_bounds and xi_bounds must be provided for numerical integration"
1192                )
1193        
1194        if self.dim == 1:
1195            x, = self.vars_x
1196            xi = symbols('xi', real=True)
1197            
1198            if volume_element is None:
1199                volume_element = 1 / (2 * pi)
1200            
1201            if numerical:
1202                # Numerical integration
1203                p_func = lambdify((x, xi), p, 'numpy')
1204                (x_min, x_max), = x_bounds
1205                (xi_min, xi_max), = xi_bounds
1206                
1207                def integrand(xi_val, x_val):
1208                    return p_func(x_val, xi_val)
1209                
1210                result, error = dblquad(
1211                    integrand,
1212                    x_min, x_max,
1213                    lambda x: xi_min, lambda x: xi_max
1214                )
1215                
1216                result *= float(volume_element)
1217                print(f"Numerical trace = {result:.6e} ± {error:.6e}")
1218                return result
1219            
1220            else:
1221                # Symbolic integration
1222                integrand = p * volume_element
1223                
1224                try:
1225                    # Try to integrate over xi first, then x
1226                    integral_xi = integrate(integrand, (xi, -oo, oo))
1227                    integral_x = integrate(integral_xi, (x, -oo, oo))
1228                    return simplify(integral_x)
1229                except:
1230                    print("Warning: Symbolic integration failed. Try numerical=True")
1231                    return integrate(integrand, (xi, -oo, oo), (x, -oo, oo))
1232        
1233        elif self.dim == 2:
1234            x, y = self.vars_x
1235            xi, eta = symbols('xi eta', real=True)
1236            
1237            if volume_element is None:
1238                volume_element = 1 / (4 * pi**2)
1239            
1240            if numerical:
1241                # Numerical integration in 4D
1242                p_func = lambdify((x, y, xi, eta), p, 'numpy')
1243                (x_min, x_max), (y_min, y_max) = x_bounds
1244                (xi_min, xi_max), (eta_min, eta_max) = xi_bounds
1245                
1246                def integrand(eta_val, xi_val, y_val, x_val):
1247                    return p_func(x_val, y_val, xi_val, eta_val)
1248                
1249                result, error = nquad(
1250                    integrand,
1251                    [
1252                        [eta_min, eta_max],
1253                        [xi_min, xi_max],
1254                        [y_min, y_max],
1255                        [x_min, x_max]
1256                    ]
1257                )
1258                
1259                result *= float(volume_element)
1260                print(f"Numerical trace = {result:.6e} ± {error:.6e}")
1261                return result
1262            
1263            else:
1264                # Symbolic integration
1265                integrand = p * volume_element
1266                
1267                try:
1268                    # Integrate in order: eta, xi, y, x
1269                    integral_eta = integrate(integrand, (eta, -oo, oo))
1270                    integral_xi = integrate(integral_eta, (xi, -oo, oo))
1271                    integral_y = integrate(integral_xi, (y, -oo, oo))
1272                    integral_x = integrate(integral_y, (x, -oo, oo))
1273                    return simplify(integral_x)
1274                except:
1275                    print("Warning: Symbolic integration failed. Try numerical=True")
1276                    return integrate(
1277                        integrand,
1278                        (eta, -oo, oo), (xi, -oo, oo),
1279                        (y, -oo, oo), (x, -oo, oo)
1280                    )
1281        
1282        else:
1283            raise NotImplementedError("Only 1D and 2D operators are supported")
1284
1285    def pseudospectrum_analysis(self, x_grid, lambda_real_range, lambda_imag_range,
1286                               epsilon_levels=[0.1, 0.01, 0.001, 0.0001],
1287                               resolution=100, method='spectral', L=None, N=None,
1288                               use_sparse=False, parallel=True, n_workers=4,
1289                               adaptive=False, adaptive_threshold=0.5,
1290                               auto_range=True, plot=True):
1291        """
1292        Compute and visualize the pseudospectrum of the operator.
1293        
1294        Optimizations:
1295        - Uses apply() method instead of manual loops
1296        - Parallel computation of resolvent norms
1297        - Sparse matrix support for large N
1298        - Optional adaptive grid refinement
1299        
1300        Parameters
1301        ----------
1302        x_grid : array
1303            Spatial grid for quantization
1304        lambda_real_range : tuple
1305            (min, max) for real part of λ
1306        lambda_imag_range : tuple
1307            (min, max) for imaginary part of λ
1308        epsilon_levels : list
1309            Levels for ε-pseudospectrum contours
1310        resolution : int
1311            Grid resolution for λ sampling
1312        method : str
1313            'spectral' or 'finite_difference'
1314        L : float, optional
1315            Domain half-length for spectral method
1316        N : int, optional
1317            Number of grid points
1318        use_sparse : bool
1319            Use sparse matrices for large N
1320        parallel : bool
1321            Enable parallel computation
1322        n_workers : int
1323            Number of parallel workers
1324        adaptive : bool
1325            Use adaptive grid refinement
1326        adaptive_threshold : float
1327            Threshold for adaptive refinement
1328            
1329        Returns
1330        -------
1331        dict
1332            Dictionary with pseudospectrum data and operator matrix
1333        """
1334        if self.dim != 1:
1335            raise NotImplementedError('Pseudospectrum analysis currently supports 1D only')
1336        
1337        # Step 1: Build operator matrix
1338        print(f"Building operator matrix using '{method}' method...")
1339        H, x_grid_used, k_grid = self._build_operator_matrix(x_grid, method, L, N)
1340        N_actual = H.shape[0]
1341        
1342        # Step 1.5: Compute eigenvalues FIRST to adjust range if needed
1343        print('Computing eigenvalues...')
1344        eigenvalues = self._compute_eigenvalues(H, use_sparse)
1345        
1346        # Auto-adjust range if requested
1347        if auto_range and eigenvalues is not None:
1348            eig_real_min, eig_real_max = eigenvalues.real.min(), eigenvalues.real.max()
1349            eig_imag_min, eig_imag_max = eigenvalues.imag.min(), eigenvalues.imag.max()
1350            
1351            # Add 20% margin around eigenvalues
1352            margin_real = 0.2 * (eig_real_max - eig_real_min + 1)
1353            margin_imag = max(0.2 * (eig_imag_max - eig_imag_min + 1), 2.0)
1354            
1355            lambda_real_range = (eig_real_min - margin_real, eig_real_max + margin_real)
1356            lambda_imag_range = (eig_imag_min - margin_imag, eig_imag_max + margin_imag)
1357            
1358            print(f'Auto-adjusted λ range:')
1359            print(f'  Re(λ) ∈ [{lambda_real_range[0]:.2f}, {lambda_real_range[1]:.2f}]')
1360            print(f'  Im(λ) ∈ [{lambda_imag_range[0]:.2f}, {lambda_imag_range[1]:.2f}]')
1361        
1362        # Step 2: Compute pseudospectrum with corrected range
1363        print(f'Computing pseudospectrum over {resolution}×{resolution} grid...')
1364        if adaptive:
1365            print('Using adaptive grid refinement...')
1366            Lambda, resolvent_norm, sigma_min_grid = self._compute_pseudospectrum_adaptive(
1367                H, lambda_real_range, lambda_imag_range, resolution,
1368                use_sparse=use_sparse, parallel=parallel, n_workers=n_workers,
1369                threshold=adaptive_threshold
1370            )
1371        else:
1372            Lambda, resolvent_norm, sigma_min_grid = self._compute_pseudospectrum(
1373                H, lambda_real_range, lambda_imag_range, resolution,
1374                use_sparse=use_sparse, parallel=parallel, n_workers=n_workers
1375            )
1376        
1377        # Step 3: Visualize
1378        if plot:
1379            self._plot_pseudospectrum(Lambda, resolvent_norm, sigma_min_grid,
1380                                      epsilon_levels, eigenvalues)
1381        
1382        return {
1383            'lambda_grid': Lambda,
1384            'resolvent_norm': resolvent_norm,
1385            'sigma_min': sigma_min_grid,
1386            'epsilon_levels': epsilon_levels,
1387            'eigenvalues': eigenvalues,
1388            'operator_matrix': H,
1389            'x_grid': x_grid_used,
1390            'k_grid': k_grid
1391        }
1392
1393
1394    def _build_operator_matrix(self, x_grid, method, L, N):
1395        """
1396        Build the discrete operator matrix H.
1397        
1398        Optimized to use the apply() method instead of manual integration.
1399        
1400        Parameters
1401        ----------
1402        x_grid : array
1403            Input spatial grid
1404        method : str
1405            'spectral' or 'finite_difference'
1406        L : float, optional
1407            Domain half-length
1408        N : int, optional
1409            Number of grid points
1410            
1411        Returns
1412        -------
1413        H : ndarray
1414            Operator matrix (N×N)
1415        x_grid_used : ndarray
1416            Actual spatial grid used
1417        k_grid : ndarray
1418            Frequency grid
1419        """
1420        if method == 'spectral':
1421            # Setup spectral grid
1422            if L is None:
1423                L = (x_grid[-1] - x_grid[0]) / 2.0
1424            if N is None:
1425                N = len(x_grid)
1426            
1427            x_grid_spectral = np.linspace(-L, L, N, endpoint=False)
1428            dx = x_grid_spectral[1] - x_grid_spectral[0]
1429            k = np.fft.fftfreq(N, d=dx) * 2.0 * np.pi
1430            
1431            # Build matrix by applying operator to canonical basis
1432            # This is the KEY OPTIMIZATION: use apply() instead of manual loops
1433            H = np.zeros((N, N), dtype=complex)
1434            
1435            for j in range(N):
1436                # Create basis vector e_j
1437                e_j = np.zeros(N, dtype=complex)
1438                e_j[j] = 1.0
1439                
1440                # Apply operator using the existing apply() method
1441                # This automatically handles the symbol evaluation and FFT operations
1442                H[:, j] = self.apply(
1443                    e_j, 
1444                    x_grid_spectral, 
1445                    k,
1446                    boundary_condition='periodic'
1447                )
1448            
1449            print(f'Operator quantized via apply() method: {N}×{N} matrix')
1450            return H, x_grid_spectral, k
1451            
1452        elif method == 'finite_difference':
1453            # Fallback to finite difference (keep original implementation)
1454            N = len(x_grid)
1455            dx = x_grid[1] - x_grid[0]
1456            H = np.zeros((N, N), dtype=complex)
1457            
1458            for i in range(N):
1459                for j in range(N):
1460                    if i == j:
1461                        H[i, j] = self.p_func(x_grid[i], 0.0)
1462                    elif abs(i - j) == 1:
1463                        xi_approx = np.pi / dx
1464                        H[i, j] = self.p_func(
1465                            (x_grid[i] + x_grid[j]) / 2,
1466                            xi_approx * np.sign(i - j)
1467                        ) / (2 * dx)
1468                    elif abs(i - j) == 2:
1469                        xi_approx = 2 * np.pi / dx
1470                        H[i, j] = self.p_func(
1471                            (x_grid[i] + x_grid[j]) / 2,
1472                            xi_approx
1473                        ) / dx ** 2
1474            
1475            print(f'Operator quantized via finite differences: {N}×{N} matrix')
1476            k = np.fft.fftfreq(N, d=dx) * 2.0 * np.pi
1477            return H, x_grid, k
1478            
1479        else:
1480            raise ValueError("method must be 'spectral' or 'finite_difference'")
1481
1482    def _compute_pseudospectrum(self, H, lambda_real_range, lambda_imag_range,
1483                               resolution, use_sparse=False, parallel=True,
1484                               n_workers=4):
1485        """
1486        Compute pseudospectrum on a uniform grid.
1487        
1488        Optimized with parallel computation and optional sparse matrices.
1489        
1490        Parameters
1491        ----------
1492        H : ndarray or sparse matrix
1493            Operator matrix
1494        lambda_real_range : tuple
1495            Range for Re(λ)
1496        lambda_imag_range : tuple
1497            Range for Im(λ)
1498        resolution : int
1499            Grid resolution
1500        use_sparse : bool
1501            Use sparse SVD for large matrices
1502        parallel : bool
1503            Enable parallel computation
1504        n_workers : int
1505            Number of parallel workers
1506            
1507        Returns
1508        -------
1509        Lambda : ndarray
1510            Complex grid of λ values
1511        resolvent_norm : ndarray
1512            Norm of (H - λI)^{-1}
1513        sigma_min_grid : ndarray
1514            Smallest singular value σ_min(H - λI)
1515        """
1516        from scipy.linalg import svdvals
1517        
1518        N = H.shape[0]
1519        lambda_re = np.linspace(*lambda_real_range, resolution)
1520        lambda_im = np.linspace(*lambda_imag_range, resolution)
1521        Lambda_re, Lambda_im = np.meshgrid(lambda_re, lambda_im)
1522        Lambda = Lambda_re + 1j * Lambda_im
1523        
1524        resolvent_norm = np.zeros_like(Lambda, dtype=float)
1525        sigma_min_grid = np.zeros_like(Lambda, dtype=float)
1526        
1527        I = np.eye(N)
1528        
1529        # Convert to sparse if requested and beneficial
1530        if use_sparse and N > 100:
1531            from scipy.sparse import csr_matrix, eye as sparse_eye
1532            from scipy.sparse.linalg import svds
1533            H_sparse = csr_matrix(H)
1534            I_sparse = sparse_eye(N, format='csr')
1535            use_sparse_svd = True
1536            print(f'Using sparse matrices (N={N})')
1537        else:
1538            use_sparse_svd = False
1539        
1540        if parallel and resolution * resolution > 100:
1541            # Parallel computation
1542            Lambda_flat = Lambda.ravel()
1543            
1544            def compute_single_point(idx):
1545                """Compute resolvent norm for a single λ value"""
1546                lam = Lambda_flat[idx]
1547                try:
1548                    if use_sparse_svd:
1549                        # Sparse SVD: compute only smallest singular value
1550                        A = H_sparse - lam * I_sparse
1551                        try:
1552                            # svds can be unstable, wrap in try-except
1553                            s_min = svds(A, k=1, which='SM', 
1554                                       return_singular_vectors=False)[0]
1555                        except:
1556                            # Fallback to dense computation
1557                            s = svdvals(A.toarray())
1558                            s_min = s[-1]
1559                    else:
1560                        # Dense SVD
1561                        A = H - lam * I
1562                        s = svdvals(A)
1563                        s_min = s[-1]
1564                    
1565                    return idx, 1.0 / (s_min + 1e-16), s_min
1566                except Exception as e:
1567                    return idx, np.nan, np.nan
1568            
1569            # Use ThreadPoolExecutor for parallel computation
1570            with ThreadPoolExecutor(max_workers=n_workers) as executor:
1571                futures = {executor.submit(compute_single_point, idx): idx 
1572                          for idx in range(len(Lambda_flat))}
1573                
1574                # Progress tracking
1575                completed = 0
1576                total = len(futures)
1577                
1578                for future in as_completed(futures):
1579                    idx, res_norm, s_min = future.result()
1580                    resolvent_norm.ravel()[idx] = res_norm
1581                    sigma_min_grid.ravel()[idx] = s_min
1582                    
1583                    completed += 1
1584                    if completed % (total // 10) == 0:
1585                        print(f'Progress: {completed}/{total} ({100*completed//total}%)')
1586            
1587        else:
1588            # Sequential computation
1589            for i in range(resolution):
1590                for j in range(resolution):
1591                    lam = Lambda[i, j]
1592                    try:
1593                        if use_sparse_svd:
1594                            A = H_sparse - lam * I_sparse
1595                            try:
1596                                s_min = svds(A, k=1, which='SM',
1597                                           return_singular_vectors=False)[0]
1598                            except:
1599                                s = svdvals(A.toarray())
1600                                s_min = s[-1]
1601                        else:
1602                            A = H - lam * I
1603                            s = svdvals(A)
1604                            s_min = s[-1]
1605                        
1606                        sigma_min_grid[i, j] = s_min
1607                        resolvent_norm[i, j] = 1.0 / (s_min + 1e-16)
1608                    except Exception:
1609                        resolvent_norm[i, j] = np.nan
1610                        sigma_min_grid[i, j] = np.nan
1611                
1612                if i % (resolution // 10) == 0:
1613                    print(f'Progress: {i}/{resolution} rows')
1614        
1615        return Lambda, resolvent_norm, sigma_min_grid
1616
1617    def _compute_pseudospectrum_adaptive(self, H, lambda_real_range, lambda_imag_range,
1618                                        base_resolution, use_sparse=False, parallel=True,
1619                                        n_workers=4, threshold=0.5, max_refinements=2):
1620        """
1621        Compute pseudospectrum with adaptive grid refinement.
1622        
1623        Starts with coarse grid and refines regions with high gradients.
1624        
1625        Parameters
1626        ----------
1627        H : ndarray
1628            Operator matrix
1629        lambda_real_range : tuple
1630            Range for Re(λ)
1631        lambda_imag_range : tuple
1632            Range for Im(λ)
1633        base_resolution : int
1634            Initial coarse resolution
1635        use_sparse : bool
1636            Use sparse matrices
1637        parallel : bool
1638            Enable parallel computation
1639        n_workers : int
1640            Number of workers
1641        threshold : float
1642            Gradient threshold for refinement
1643        max_refinements : int
1644            Maximum number of refinement levels
1645            
1646        Returns
1647        -------
1648        Lambda : ndarray
1649            Complex grid (may be non-uniform)
1650        resolvent_norm : ndarray
1651            Resolvent norms
1652        sigma_min_grid : ndarray
1653            Smallest singular values
1654        """
1655        # Start with coarse grid
1656        coarse_res = base_resolution // 2
1657        print(f'Level 0: Computing coarse grid ({coarse_res}×{coarse_res})...')
1658        
1659        Lambda_coarse, resolvent_coarse, sigma_coarse = self._compute_pseudospectrum(
1660            H, lambda_real_range, lambda_imag_range, coarse_res,
1661            use_sparse=use_sparse, parallel=parallel, n_workers=n_workers
1662        )
1663        
1664        # Compute gradient to identify regions needing refinement
1665        log_resolvent = np.log10(resolvent_coarse + 1e-16)
1666        grad_y, grad_x = np.gradient(log_resolvent)
1667        grad_magnitude = np.sqrt(grad_x**2 + grad_y**2)
1668        
1669        # Normalize gradient
1670        grad_normalized = grad_magnitude / (np.max(grad_magnitude) + 1e-10)
1671        
1672        # For now, return uniform fine grid
1673        # (Full adaptive implementation would require irregular grids)
1674        print(f'Level 1: Computing fine grid ({base_resolution}×{base_resolution})...')
1675        Lambda_fine, resolvent_fine, sigma_fine = self._compute_pseudospectrum(
1676            H, lambda_real_range, lambda_imag_range, base_resolution,
1677            use_sparse=use_sparse, parallel=parallel, n_workers=n_workers
1678        )
1679        
1680        high_gradient_pct = 100 * np.sum(grad_normalized > threshold) / grad_normalized.size
1681        print(f'High-gradient regions: {high_gradient_pct:.1f}% of domain')
1682        
1683        return Lambda_fine, resolvent_fine, sigma_fine
1684
1685    def _compute_eigenvalues(self, H, use_sparse=False):
1686        """
1687        Compute eigenvalues of operator matrix.
1688        
1689        Parameters
1690        ----------
1691        H : ndarray
1692            Operator matrix
1693        use_sparse : bool
1694            Use sparse eigenvalue solver
1695            
1696        Returns
1697        -------
1698        eigenvalues : ndarray or None
1699            Eigenvalues of H
1700        """
1701        try:
1702            if use_sparse and H.shape[0] > 100:
1703                from scipy.sparse.linalg import eigs
1704                from scipy.sparse import csr_matrix
1705                H_sparse = csr_matrix(H)
1706                k = min(20, H.shape[0] - 2)
1707                eigenvalues = eigs(H_sparse, k=k, return_eigenvectors=False)
1708            else:
1709                eigenvalues = np.linalg.eigvals(H)
1710            
1711            # Print diagnostics
1712            print(f'Eigenvalue range: [{eigenvalues.real.min():.2f}, {eigenvalues.real.max():.2f}]')
1713            print(f'Imaginary part range: [{eigenvalues.imag.min():.2e}, {eigenvalues.imag.max():.2e}]')
1714            
1715            return eigenvalues
1716        except Exception as e:
1717            warnings.warn(f'Eigenvalue computation failed: {e}')
1718            return None
1719
1720    def _plot_pseudospectrum(self, Lambda, resolvent_norm, sigma_min_grid,
1721                            epsilon_levels, eigenvalues):
1722        """
1723        Plot pseudospectrum results.
1724        
1725        Parameters
1726        ----------
1727        Lambda : ndarray
1728            Complex λ grid
1729        resolvent_norm : ndarray
1730            Resolvent norms
1731        sigma_min_grid : ndarray
1732            Smallest singular values
1733        epsilon_levels : list
1734            Contour levels
1735        eigenvalues : ndarray or None
1736            Eigenvalues to overlay
1737        """
1738        Lambda_re = Lambda.real
1739        Lambda_im = Lambda.imag
1740        
1741        plt.figure(figsize=(14, 6))
1742        
1743        # Left plot: ε-pseudospectrum
1744        plt.subplot(1, 2, 1)
1745        
1746        # Better contour level computation
1747        log_resolvent = np.log10(resolvent_norm + 1e-16)
1748        levels_log = np.log10(1.0 / np.array(epsilon_levels))
1749        
1750        # Only plot contours that exist in the data range
1751        valid_levels = [lv for lv in levels_log 
1752                       if log_resolvent.min() <= lv <= log_resolvent.max()]
1753        
1754        if len(valid_levels) > 0:
1755            cs = plt.contour(Lambda_re, Lambda_im, log_resolvent,
1756                            levels=valid_levels, colors='blue', linewidths=1.5)
1757            # Better labels
1758            labels = [f'ε={eps:.0e}' for eps in epsilon_levels[:len(valid_levels)]]
1759            fmt = dict(zip(cs.levels, labels))
1760            plt.clabel(cs, inline=True, fmt=fmt, fontsize=9)
1761        else:
1762            print('⚠️ Warning: No contours in specified epsilon range')
1763            # Plot general contours
1764            cs = plt.contour(Lambda_re, Lambda_im, log_resolvent,
1765                            levels=10, colors='blue', linewidths=1.5)
1766        
1767        if eigenvalues is not None:
1768            plt.plot(eigenvalues.real, eigenvalues.imag, 'r*', 
1769                    markersize=10, label='Eigenvalues', markeredgecolor='darkred')
1770        
1771        plt.xlabel('Re(λ)', fontsize=12)
1772        plt.ylabel('Im(λ)', fontsize=12)
1773        plt.title('ε-Pseudospectrum: log₁₀(‖(H - λI)⁻¹‖)', fontsize=13)
1774        plt.grid(alpha=0.3)
1775        plt.legend(fontsize=10)
1776        plt.axis('equal')
1777        
1778        # Right plot: Smallest singular value
1779        plt.subplot(1, 2, 2)
1780        
1781        # Use better colormap normalization
1782        from matplotlib.colors import LogNorm
1783        
1784        # Filter out invalid values
1785        sigma_plot = np.where(np.isfinite(sigma_min_grid), sigma_min_grid, np.nan)
1786        vmin = np.nanmin(sigma_plot[sigma_plot > 0]) if np.any(sigma_plot > 0) else 1e-10
1787        vmax = np.nanmax(sigma_plot)
1788        
1789        cs2 = plt.contourf(Lambda_re, Lambda_im, sigma_plot,
1790                          levels=50, cmap='viridis',
1791                          norm=LogNorm(vmin=vmin, vmax=vmax))
1792        plt.colorbar(cs2, label='σ_min(H - λI)')
1793        
1794        if eigenvalues is not None:
1795            plt.plot(eigenvalues.real, eigenvalues.imag, 'r*', 
1796                    markersize=10, markeredgecolor='darkred')
1797        
1798        # Plot epsilon contours
1799        for eps in epsilon_levels:
1800            cs_eps = plt.contour(Lambda_re, Lambda_im, sigma_plot,
1801                               levels=[eps], colors='red', linewidths=2, alpha=0.8)
1802        
1803        plt.xlabel('Re(λ)', fontsize=12)
1804        plt.ylabel('Im(λ)', fontsize=12)
1805        plt.title('Smallest singular value σ_min(H - λI)', fontsize=13)
1806        plt.grid(alpha=0.3)
1807        plt.axis('equal')
1808        
1809        plt.tight_layout()
1810        plt.show()
1811    
1812    
1813    def symplectic_flow(self):
1814        """
1815        Compute the Hamiltonian vector field associated with the principal symbol.
1816
1817        This method derives the canonical equations of motion for the phase space variables 
1818        (x, ξ) in 1D or (x, y, ξ, η) in 2D, based on the Hamiltonian formalism. These describe 
1819        how position and frequency variables evolve under the flow generated by the symbol.
1820
1821        Returns
1822        -------
1823        dict
1824            A dictionary containing the components of the Hamiltonian vector field:
1825            - In 1D: keys are 'dx/dt' and 'dxi/dt', corresponding to dx/dt = ∂p/∂ξ and dξ/dt = -∂p/∂x.
1826            - In 2D: keys are 'dx/dt', 'dy/dt', 'dxi/dt', and 'deta/dt', with similar definitions:
1827              dx/dt = ∂p/∂ξ, dy/dt = ∂p/∂η, dξ/dt = -∂p/∂x, dη/dt = -∂p/∂y.
1828
1829        Notes
1830        -----
1831        - The Hamiltonian here is the principal symbol p(x, ξ) itself.
1832        - This flow preserves the symplectic structure of phase space.
1833        """
1834        if self.dim == 1:
1835            x,  = self.vars_x
1836            xi = symbols('xi', real=True)
1837            return {
1838                'dx/dt': diff(self.symbol, xi),
1839                'dxi/dt': -diff(self.symbol, x)
1840            }
1841        elif self.dim == 2:
1842            x, y = self.vars_x
1843            xi, eta = symbols('xi eta', real=True)
1844            return {
1845                'dx/dt': diff(self.symbol, xi),
1846                'dy/dt': diff(self.symbol, eta),
1847                'dxi/dt': -diff(self.symbol, x),
1848                'deta/dt': -diff(self.symbol, y)
1849            }
1850
1851    def is_elliptic_numerically(self, x_grid, xi_grid, threshold=1e-8):
1852        """
1853        Check if the pseudo-differential symbol p(x, ξ) is elliptic over a given grid.
1854    
1855        A symbol is considered elliptic if its magnitude |p(x, ξ)| remains bounded away from zero 
1856        across all points in the spatial-frequency domain. This method evaluates the symbol on a 
1857        grid of spatial and frequency coordinates and checks whether its minimum absolute value 
1858        exceeds a specified threshold.
1859    
1860        Resampling is applied to large grids to prevent excessive memory usage, particularly in 2D.
1861    
1862        Parameters
1863        ----------
1864        x_grid : ndarray
1865            Spatial grid: either a 1D array (x) or a tuple of two 1D arrays (x, y).
1866        xi_grid : ndarray
1867            Frequency grid: either a 1D array (ξ) or a tuple of two 1D arrays (ξ, η).
1868        threshold : float, optional
1869            Minimum acceptable value for |p(x, ξ)|. If the smallest evaluated symbol value falls below this,
1870            the symbol is not considered elliptic.
1871    
1872        Returns
1873        -------
1874        bool
1875            True if the symbol is elliptic on the resampled grid, False otherwise.
1876        """
1877        RESAMPLE_SIZE = 32  # Reduced size to prevent memory explosion
1878        
1879        if self.dim == 1:
1880            x_vals = x_grid
1881            xi_vals = xi_grid
1882            # Resampling if necessary
1883            if len(x_vals) > RESAMPLE_SIZE:
1884                x_vals = np.linspace(x_vals.min(), x_vals.max(), RESAMPLE_SIZE)
1885            if len(xi_vals) > RESAMPLE_SIZE:
1886                xi_vals = np.linspace(xi_vals.min(), xi_vals.max(), RESAMPLE_SIZE)
1887        
1888            X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
1889            symbol_vals = self.p_func(X, XI)
1890        
1891        elif self.dim == 2:
1892            x_vals, y_vals = x_grid
1893            xi_vals, eta_vals = xi_grid
1894        
1895            # Spatial resampling
1896            if len(x_vals) > RESAMPLE_SIZE:
1897                x_vals = np.linspace(x_vals.min(), x_vals.max(), RESAMPLE_SIZE)
1898            if len(y_vals) > RESAMPLE_SIZE:
1899                y_vals = np.linspace(y_vals.min(), y_vals.max(), RESAMPLE_SIZE)
1900        
1901            # Frequency resampling
1902            if len(xi_vals) > RESAMPLE_SIZE:
1903                xi_vals = np.linspace(xi_vals.min(), xi_vals.max(), RESAMPLE_SIZE)
1904            if len(eta_vals) > RESAMPLE_SIZE:
1905                eta_vals = np.linspace(eta_vals.min(), eta_vals.max(), RESAMPLE_SIZE)
1906        
1907            X, Y, XI, ETA = np.meshgrid(x_vals, y_vals, xi_vals, eta_vals, indexing='ij')
1908            symbol_vals = self.p_func(X, Y, XI, ETA)
1909        
1910        min_abs_val = np.min(np.abs(symbol_vals))
1911        return min_abs_val > threshold
1912
1913
1914    def is_self_adjoint(self, tol=1e-10):
1915        """
1916        Check whether the pseudo-differential operator is formally self-adjoint (Hermitian).
1917
1918        A self-adjoint operator satisfies P = P*, where P* is the formal adjoint of P.
1919        This property is essential for ensuring real-valued eigenvalues and stable evolution 
1920        in quantum mechanics and symmetric wave propagation.
1921
1922        Parameters
1923        ----------
1924        tol : float
1925            Tolerance for symbolic comparison between P and P*. Small numerical differences 
1926            below this threshold are considered equal.
1927
1928        Returns
1929        -------
1930        bool
1931            True if the symbol p(x, ξ) equals its formal adjoint p*(x, ξ) within the given tolerance,
1932            indicating that the operator is self-adjoint.
1933
1934        Notes:
1935        - The formal adjoint is computed via conjugation and asymptotic expansion at infinity in ξ.
1936        - Symbolic simplification is used to verify equality, ensuring robustness against superficial 
1937          expression differences.
1938        """
1939        p = self.symbol
1940        p_star = self.formal_adjoint()
1941        return simplify(p - p_star).equals(0)
1942
1943    def visualize_fiber(self, x_grid, xi_grid, x0=0.0, y0=0.0):
1944        """
1945        Plot the cotangent fiber structure at a fixed spatial point (x₀[, y₀]).
1946    
1947        This visualization shows how the symbol p(x, ξ) behaves on the cotangent fiber 
1948        above a fixed spatial point. In microlocal analysis, this provides insight into 
1949        the frequency content of the operator at that location.
1950    
1951        Parameters
1952        ----------
1953        x_grid : ndarray
1954            Spatial grid values (1D) for evaluation in 1D case.
1955        xi_grid : ndarray
1956            Frequency grid values (1D) for evaluation in both 1D and 2D cases.
1957        x0 : float, optional
1958            Fixed x-coordinate of the base point in space (1D or 2D).
1959        y0 : float, optional
1960            Fixed y-coordinate of the base point in space (2D only).
1961    
1962        Notes
1963        -----
1964        - In 1D: Displays |p(x, ξ)| over the (x, ξ) phase plane near the fixed point.
1965        - In 2D: Fixes (x₀, y₀) and evaluates p(x₀, y₀, ξ, η), showing the fiber over that point.
1966        - The color map represents the magnitude of the symbol, highlighting regions where it vanishes or becomes singular.
1967    
1968        Raises
1969        ------
1970        NotImplementedError
1971            If called in 2D with missing or improperly formatted grids.
1972        """
1973        if self.dim == 1:
1974            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
1975            symbol_vals = self.p_func(X, XI)
1976            plt.contourf(X, XI, np.abs(symbol_vals), levels=50, cmap='viridis')
1977            plt.colorbar(label='|Symbol|')
1978            plt.xlabel('x (position)')
1979            plt.ylabel('ξ (frequency)')
1980            plt.title('Cotangent Fiber Structure')
1981            plt.show()
1982        elif self.dim == 2:
1983            xi_grid2, eta_grid2 = np.meshgrid(xi_grid, xi_grid)
1984            symbol_vals = self.p_func(x0, y0, xi_grid2, eta_grid2)
1985            plt.contourf(xi_grid, xi_grid, np.abs(symbol_vals), levels=50, cmap='viridis')
1986            plt.colorbar(label='|Symbol|')
1987            plt.xlabel('ξ')
1988            plt.ylabel('η')
1989            plt.title(f'Cotangent Fiber at x={x0}, y={y0}')
1990            plt.show()
1991
1992    def visualize_symbol_amplitude(self, x_grid, xi_grid, y_grid=None, eta_grid=None, xi0=0.0, eta0=0.0):
1993        """
1994        Display the modulus |p(x, ξ)| or |p(x, y, ξ₀, η₀)| as a color map.
1995    
1996        This method visualizes the amplitude of the pseudodifferential operator's symbol 
1997        in either 1D or 2D spatial configuration. In 2D, the frequency variables are fixed 
1998        to specified values (ξ₀, η₀) for visualization purposes.
1999    
2000        Parameters
2001        ----------
2002        x_grid, y_grid : ndarray
2003            Spatial grids over which to evaluate the symbol. y_grid is optional and used only in 2D.
2004        xi_grid, eta_grid : ndarray
2005            Frequency grids. In 2D, these define the domain over which the symbol is evaluated,
2006            but the visualization fixes ξ = ξ₀ and η = η₀.
2007        xi0, eta0 : float, optional
2008            Fixed frequency values for slicing in 2D visualization. Defaults to zero.
2009    
2010        Notes
2011        -----
2012        - In 1D: Visualizes |p(x, ξ)| over the (x, ξ) grid.
2013        - In 2D: Visualizes |p(x, y, ξ₀, η₀)| at fixed frequencies ξ₀ and η₀.
2014        - The color intensity represents the magnitude of the symbol, highlighting regions where the symbol is large or small.
2015        """
2016        if self.dim == 1:
2017            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
2018            symbol_vals = self.p_func(X, XI) 
2019            plt.pcolormesh(X, XI, np.abs(symbol_vals), shading='auto')
2020            plt.colorbar(label='|Symbol|')
2021            plt.xlabel('x')
2022            plt.ylabel('ξ')
2023            plt.title('Symbol Amplitude |p(x, ξ)|')
2024            plt.show()
2025        elif self.dim == 2:
2026            X, Y = np.meshgrid(x_grid, y_grid, indexing='ij')
2027            XI = np.full_like(X, xi0)
2028            ETA = np.full_like(Y, eta0)
2029            symbol_vals = self.p_func(X, Y, XI, ETA)
2030            plt.pcolormesh(X, Y, np.abs(symbol_vals), shading='auto')
2031            plt.colorbar(label='|Symbol|')
2032            plt.xlabel('x')
2033            plt.ylabel('y')
2034            plt.title(f'Symbol Amplitude at ξ={xi0}, η={eta0}')
2035            plt.show()
2036
2037    def visualize_phase(self, x_grid, xi_grid, y_grid=None, eta_grid=None, xi0=0.0, eta0=0.0):
2038        """
2039        Plot the phase (argument) of the pseudodifferential operator's symbol p(x, ξ) or p(x, y, ξ, η).
2040
2041        This visualization helps in understanding the oscillatory behavior and regularity 
2042        properties of the operator in phase space. The phase is displayed modulo 2π using 
2043        a cyclic colormap ('twilight') to emphasize its periodic nature.
2044
2045        Parameters
2046        ----------
2047        x_grid : ndarray
2048            1D array of spatial coordinates (x).
2049        xi_grid : ndarray
2050            1D array of frequency coordinates (ξ).
2051        y_grid : ndarray, optional
2052            2D spatial grid for y-coordinate (in 2D problems). Default is None.
2053        eta_grid : ndarray, optional
2054            2D frequency grid for η (in 2D problems). Not used directly but kept for API consistency.
2055        xi0 : float, optional
2056            Fixed value of ξ for slicing in 2D visualization. Default is 0.0.
2057        eta0 : float, optional
2058            Fixed value of η for slicing in 2D visualization. Default is 0.0.
2059
2060        Notes:
2061        - In 1D: Displays arg(p(x, ξ)) over the (x, ξ) phase plane.
2062        - In 2D: Displays arg(p(x, y, ξ₀, η₀)) for fixed frequency values (ξ₀, η₀).
2063        - Uses plt.pcolormesh with 'twilight' colormap to represent angles from -π to π.
2064
2065        Raises:
2066        - NotImplementedError: If the spatial dimension is not 1D or 2D.
2067        """
2068        if self.dim == 1:
2069            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
2070            symbol_vals = self.p_func(X, XI) 
2071            plt.pcolormesh(X, XI, np.angle(symbol_vals), shading='auto', cmap='twilight')
2072            plt.colorbar(label='arg(Symbol) [rad]')
2073            plt.xlabel('x')
2074            plt.ylabel('ξ')
2075            plt.title('Phase Portrait (arg p(x, ξ))')
2076            plt.show()
2077        elif self.dim == 2:
2078            X, Y = np.meshgrid(x_grid, y_grid, indexing='ij')
2079            XI = np.full_like(X, xi0)
2080            ETA = np.full_like(Y, eta0)
2081            symbol_vals = self.p_func(X, Y, XI, ETA)
2082            plt.pcolormesh(X, Y, np.angle(symbol_vals), shading='auto', cmap='twilight')
2083            plt.colorbar(label='arg(Symbol) [rad]')
2084            plt.xlabel('x')
2085            plt.ylabel('y')
2086            plt.title(f'Phase Portrait at ξ={xi0}, η={eta0}')
2087            plt.show()
2088            
2089    def visualize_characteristic_set(self, x_grid, xi_grid, y_grid=None, eta_grid=None, y0=0.0, x0=0.0, levels=[1e-1]):
2090        """
2091        Visualize the characteristic set of the pseudo-differential symbol, defined as the approximate zero set p(x, ξ) ≈ 0.
2092    
2093        In microlocal analysis, the characteristic set is the locus of points in phase space (x, ξ) where the symbol p(x, ξ) vanishes,
2094        playing a key role in understanding propagation of singularities.
2095    
2096        Parameters
2097        ----------
2098        x_grid : ndarray
2099            Spatial grid values (1D array) for plotting in 1D or evaluation point in 2D.
2100        xi_grid : ndarray
2101            Frequency variable grid values (1D array) used to construct the frequency domain.
2102        x0 : float, optional
2103            Fixed spatial coordinate in 2D case for evaluating the symbol at a specific x position.
2104        y0 : float, optional
2105            Fixed spatial coordinate in 2D case for evaluating the symbol at a specific y position.
2106    
2107        Notes
2108        -----
2109        - For 1D, this method plots the contour of |p(x, ξ)| = ε with ε = 1e-5 over the (x, ξ) plane.
2110        - For 2D, it evaluates the symbol at fixed (x₀, y₀) and plots the characteristic set in the (ξ, η) frequency plane.
2111        - This visualization helps identify directions of degeneracy or hypoellipticity of the operator.
2112    
2113        Raises
2114        ------
2115        NotImplementedError
2116            If called on a solver with dimensionality other than 1D or 2D.
2117    
2118        Displays
2119        ------
2120        A matplotlib contour plot showing either:
2121            - The characteristic curve in the (x, ξ) phase plane (1D),
2122            - The characteristic surface slice in the (ξ, η) frequency plane at (x₀, y₀) (2D).
2123        """
2124        if self.dim == 1:
2125            x_grid = np.asarray(x_grid)
2126            xi_grid = np.asarray(xi_grid)
2127            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
2128            symbol_vals = self.p_func(X, XI) 
2129            plt.contour(X, XI, np.abs(symbol_vals), levels=levels, colors='red')
2130            plt.xlabel('x')
2131            plt.ylabel('ξ')
2132            plt.title('Characteristic Set (p(x, ξ) ≈ 0)')
2133            plt.grid(True)
2134            plt.show()
2135        elif self.dim == 2:
2136            if eta_grid is None:
2137                raise ValueError("eta_grid must be provided for 2D visualization.")
2138            xi_grid = np.asarray(xi_grid)
2139            eta_grid = np.asarray(eta_grid)
2140            xi_grid2, eta_grid2 = np.meshgrid(xi_grid, eta_grid, indexing='ij')
2141            symbol_vals = self.p_func(x0, y0, xi_grid2, eta_grid2)
2142            plt.contour(xi_grid, eta_grid, np.abs(symbol_vals), levels=levels, colors='red')
2143            plt.xlabel('ξ')
2144            plt.ylabel('η')
2145            plt.title(f'Characteristic Set at x={x0}, y={y0}')
2146            plt.grid(True)
2147            plt.show()
2148        else:
2149            raise NotImplementedError("Only 1D/2D characteristic sets supported.")
2150
2151    def visualize_characteristic_gradient(self, x_grid, xi_grid, y_grid=None, eta_grid=None, y0=0.0, x0=0.0):
2152        """
2153        Visualize the norm of the gradient of the symbol in phase space.
2154        
2155        This method computes the magnitude of the gradient |∇p| of a pseudo-differential 
2156        symbol p(x, ξ) in 1D or p(x, y, ξ, η) in 2D. The resulting colormap reveals 
2157        regions where the symbol varies rapidly or remains nearly stationary, 
2158        which is particularly useful for analyzing characteristic sets.
2159        
2160        Parameters
2161        ----------
2162        x_grid : numpy.ndarray
2163            1D array of spatial coordinates for the x-direction.
2164        xi_grid : numpy.ndarray
2165            1D array of frequency coordinates (ξ).
2166        y_grid : numpy.ndarray, optional
2167            1D array of spatial coordinates for the y-direction (used in 2D mode). Default is None.
2168        eta_grid : numpy.ndarray, optional
2169            1D array of frequency coordinates (η) for the 2D case. Default is None.
2170        x0 : float, optional
2171            Fixed x-coordinate for evaluating the symbol in 2D. Default is 0.0.
2172        y0 : float, optional
2173            Fixed y-coordinate for evaluating the symbol in 2D. Default is 0.0.
2174        
2175        Returns
2176        -------
2177        None
2178            Displays a 2D colormap of |∇p| over the relevant phase-space domain.
2179        
2180        Notes
2181        -----
2182        - In 1D, the full gradient ∇p = (∂ₓp, ∂ξp) is computed over the (x, ξ) grid.
2183        - In 2D, the gradient ∇p = (∂ξp, ∂ηp) is computed at a fixed spatial point (x₀, y₀) over the (ξ, η) grid.
2184        - Numerical differentiation is performed using `np.gradient`.
2185        - High values of |∇p| indicate rapid variation of the symbol, while low values typically suggest characteristic regions.
2186        """
2187        if self.dim == 1:
2188            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
2189            symbol_vals = self.p_func(X, XI)
2190            grad_x = np.gradient(symbol_vals, axis=0)
2191            grad_xi = np.gradient(symbol_vals, axis=1)
2192            grad_norm = np.sqrt(grad_x**2 + grad_xi**2)
2193            plt.pcolormesh(X, XI, grad_norm, cmap='inferno', shading='auto')
2194            plt.colorbar(label='|∇p|')
2195            plt.xlabel('x')
2196            plt.ylabel('ξ')
2197            plt.title('Gradient Norm (High Near Zeros)')
2198            plt.grid(True)
2199            plt.show()
2200        elif self.dim == 2:
2201            xi_grid2, eta_grid2 = np.meshgrid(xi_grid, eta_grid, indexing='ij')
2202            symbol_vals = self.p_func(x0, y0, xi_grid2, eta_grid2)
2203            grad_xi = np.gradient(symbol_vals, axis=0)
2204            grad_eta = np.gradient(symbol_vals, axis=1)
2205            grad_norm = np.sqrt(np.abs(grad_xi)**2 + np.abs(grad_eta)**2)
2206            plt.pcolormesh(xi_grid, eta_grid, grad_norm, cmap='inferno', shading='auto')
2207            plt.colorbar(label='|∇p|')
2208            plt.xlabel('ξ')
2209            plt.ylabel('η')
2210            plt.title(f'Gradient Norm at x={x0}, y={y0}')
2211            plt.grid(True)
2212            plt.show()
2213
2214    def plot_hamiltonian_flow(self, x0=0.0, xi0=5.0, y0=0.0, eta0=0.0, tmax=1.0, n_steps=100, show_field=True):
2215        """
2216        Integrate and plot the Hamiltonian trajectories of the symbol in phase space.
2217
2218        This method numerically integrates the Hamiltonian vector field derived from 
2219        the operator's symbol to visualize how singularities propagate under the flow. 
2220        It supports both 1D and 2D problems.
2221
2222        Parameters
2223        ----------
2224        x0, xi0 : float
2225            Initial position and frequency (momentum) in 1D.
2226        y0, eta0 : float, optional
2227            Initial position and frequency in 2D; defaults to zero.
2228        tmax : float
2229            Final integration time for the ODE solver.
2230        n_steps : int
2231            Number of time steps used in the integration.
2232
2233        Notes
2234        -----
2235        - The Hamiltonian vector field is obtained from the symplectic flow of the symbol.
2236        - If the field is complex-valued, only its real part is used for integration.
2237        - In 1D, the trajectory is plotted in (x, ξ) phase space.
2238        - In 2D, the spatial trajectory (x(t), y(t)) is shown along with instantaneous 
2239          momentum vectors (ξ(t), η(t)) using a quiver plot.
2240
2241        Raises
2242        ------
2243        NotImplementedError
2244            If the spatial dimension is not 1D or 2D.
2245
2246        Displays
2247        --------
2248        matplotlib plot
2249            Phase space trajectory(ies) showing the evolution of position and momentum 
2250            under the Hamiltonian dynamics.
2251        """
2252        def make_real(expr):
2253            from sympy import re, simplify
2254            expr = expr.doit(deep=True)
2255            return simplify(re(expr))
2256    
2257        H = self.symplectic_flow()
2258    
2259        if any(im(H[k]) != 0 for k in H):
2260            print("⚠️ The Hamiltonian field is complex. Only the real part is used for integration.")
2261    
2262        if self.dim == 1:
2263            x, = self.vars_x
2264            xi = symbols('xi', real=True)
2265    
2266            dxdt_expr = make_real(H['dx/dt'])
2267            dxidt_expr = make_real(H['dxi/dt'])
2268    
2269            dxdt = lambdify((x, xi), dxdt_expr, 'numpy')
2270            dxidt = lambdify((x, xi), dxidt_expr, 'numpy')
2271    
2272            def hamilton(t, Y):
2273                x, xi = Y
2274                return [dxdt(x, xi), dxidt(x, xi)]
2275    
2276            sol = solve_ivp(hamilton, [0, tmax], [x0, xi0], t_eval=np.linspace(0, tmax, n_steps))
2277
2278            if sol.status != 0:
2279                print(f"⚠️ Integration warning: {sol.message}")
2280            
2281            n_points = sol.y.shape[1]
2282            if n_points < n_steps:
2283                print(f"⚠️ Only {n_points} frames computed. Adjusting animation.")
2284                n_steps = n_points
2285
2286            x_vals, xi_vals = sol.y
2287    
2288            plt.plot(x_vals, xi_vals)
2289            plt.xlabel("x")
2290            plt.ylabel("ξ")
2291            plt.title("Hamiltonian Flow in Phase Space (1D)")
2292            plt.grid(True)
2293            plt.show()
2294    
2295        elif self.dim == 2:
2296            x, y = self.vars_x
2297            xi, eta = symbols('xi eta', real=True)
2298    
2299            dxdt = lambdify((x, y, xi, eta), make_real(H['dx/dt']), 'numpy')
2300            dydt = lambdify((x, y, xi, eta), make_real(H['dy/dt']), 'numpy')
2301            dxidt = lambdify((x, y, xi, eta), make_real(H['dxi/dt']), 'numpy')
2302            detadt = lambdify((x, y, xi, eta), make_real(H['deta/dt']), 'numpy')
2303    
2304            def hamilton(t, Y):
2305                x, y, xi, eta = Y
2306                return [
2307                    dxdt(x, y, xi, eta),
2308                    dydt(x, y, xi, eta),
2309                    dxidt(x, y, xi, eta),
2310                    detadt(x, y, xi, eta)
2311                ]
2312    
2313            sol = solve_ivp(hamilton, [0, tmax], [x0, y0, xi0, eta0], t_eval=np.linspace(0, tmax, n_steps))
2314
2315            if sol.status != 0:
2316                print(f"⚠️ Integration warning: {sol.message}")
2317            
2318            n_points = sol.y.shape[1]
2319            if n_points < n_steps:
2320                print(f"⚠️ Only {n_points} frames computed. Adjusting animation.")
2321                n_steps = n_points
2322
2323            x_vals, y_vals, xi_vals, eta_vals = sol.y
2324    
2325            plt.plot(x_vals, y_vals, label='Position')
2326            plt.quiver(x_vals, y_vals, xi_vals, eta_vals, scale=20, width=0.003, alpha=0.5, color='r')
2327            
2328            # Vector field of the flow (optional)
2329            if show_field:
2330                X, Y = np.meshgrid(np.linspace(min(x_vals), max(x_vals), 20),
2331                                   np.linspace(min(y_vals), max(y_vals), 20))
2332                XI, ETA = xi0 * np.ones_like(X), eta0 * np.ones_like(Y)
2333                U = dxdt(X, Y, XI, ETA)
2334                V = dydt(X, Y, XI, ETA)
2335                plt.quiver(X, Y, U, V, color='gray', alpha=0.2, scale=30, width=0.002)
2336
2337            plt.xlabel("x")
2338            plt.ylabel("y")
2339            plt.title("Hamiltonian Flow in Phase Space (2D)")
2340            plt.legend()
2341            plt.grid(True)
2342            plt.axis('equal')
2343            plt.show()
2344
2345    def plot_symplectic_vector_field(self, xlim=(-2, 2), klim=(-5, 5), density=30):
2346        """
2347        Visualize the symplectic vector field (Hamiltonian vector field) associated with the operator's symbol.
2348
2349        The plotted vector field corresponds to (∂_ξ p, -∂_x p), where p(x, ξ) is the principal symbol 
2350        of the pseudo-differential operator. This field governs the bicharacteristic flow in phase space.
2351
2352        Parameters
2353        ----------
2354        xlim : tuple of float
2355            Range for spatial variable x, as (x_min, x_max).
2356        klim : tuple of float
2357            Range for frequency variable ξ, as (ξ_min, ξ_max).
2358        density : int
2359            Number of grid points per axis for the visualization grid.
2360
2361        Raises
2362        ------
2363        NotImplementedError
2364            If called on a 2D operator (currently only 1D implementation available).
2365
2366        Notes
2367        -----
2368        - Only supports one-dimensional operators.
2369        - Uses symbolic differentiation to compute ∂_ξ p and ∂_x p.
2370        - Numerical evaluation is done via lambdify with NumPy backend.
2371        - Visualization uses matplotlib quiver plot to show vector directions.
2372        """
2373        x_vals = np.linspace(*xlim, density)
2374        xi_vals = np.linspace(*klim, density)
2375        X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
2376
2377        if self.dim != 1:
2378            raise NotImplementedError("Only 1D version implemented.")
2379
2380        x, = self.vars_x
2381        xi = symbols('xi', real=True)
2382        H = self.symplectic_flow()
2383        dxdt = lambdify((x, xi), simplify(H['dx/dt']), 'numpy')
2384        dxidt = lambdify((x, xi), simplify(H['dxi/dt']), 'numpy')
2385
2386        U = dxdt(X, XI)
2387        V = dxidt(X, XI)
2388
2389        plt.quiver(X, XI, U, V, scale=10, width=0.005)
2390        plt.xlabel('x')
2391        plt.ylabel(r'$\xi$')
2392        plt.title("Symplectic Vector Field (1D)")
2393        plt.grid(True)
2394        plt.show()
2395
2396    def visualize_micro_support(self, xlim=(-2, 2), klim=(-10, 10), threshold=1e-3, density=300):
2397        """
2398        Visualize the micro-support of the operator by plotting the inverse of the symbol magnitude 1 / |p(x, ξ)|.
2399    
2400        The micro-support provides insight into the singularities of a pseudo-differential operator 
2401        in phase space (x, ξ). Regions where |p(x, ξ)| is small correspond to large values in 1/|p(x, ξ)|,
2402        highlighting areas of significant operator influence or singularity.
2403    
2404        Parameters
2405        ----------
2406        xlim : tuple
2407            Spatial domain limits (x_min, x_max).
2408        klim : tuple
2409            Frequency domain limits (ξ_min, ξ_max).
2410        threshold : float
2411            Threshold below which |p(x, ξ)| is considered effectively zero; used for numerical stability.
2412        density : int
2413            Number of grid points along each axis for visualization resolution.
2414    
2415        Raises
2416        ------
2417        NotImplementedError
2418            If called on a solver with dimension greater than 1 (only 1D visualization is supported).
2419    
2420        Notes
2421        -----
2422        - This method evaluates the symbol p(x, ξ) over a grid and plots its reciprocal to emphasize 
2423          regions where the symbol is near zero.
2424        - A small constant (1e-10) is added to the denominator to avoid division by zero.
2425        - The resulting plot helps identify characteristic sets.
2426        """
2427        if self.dim != 1:
2428            raise NotImplementedError("Only 1D micro-support visualization implemented.")
2429
2430        x_vals = np.linspace(*xlim, density)
2431        xi_vals = np.linspace(*klim, density)
2432        X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
2433        Z = np.abs(self.p_func(X, XI))
2434
2435        plt.contourf(X, XI, 1 / (Z + 1e-10), levels=100, cmap='inferno')
2436        plt.colorbar(label=r'$1/|p(x,\xi)|$')
2437        plt.xlabel('x')
2438        plt.ylabel(r'$\xi$')
2439        plt.title("Micro-Support Estimate (1/|Symbol|)")
2440        plt.show()
2441
2442    def group_velocity_field(self, xlim=(-2, 2), klim=(-10, 10), density=30):
2443        """
2444        Plot the group velocity field ∇_ξ p(x, ξ) for 1D pseudo-differential operators.
2445
2446        The group velocity represents the speed at which waves of different frequencies propagate 
2447        in a dispersive medium. It is defined as the gradient of the symbol p(x, ξ) with respect 
2448        to the frequency variable ξ.
2449
2450        Parameters
2451        ----------
2452        xlim : tuple of float
2453            Spatial domain limits (x-axis).
2454        klim : tuple of float
2455            Frequency domain limits (ξ-axis).
2456        density : int
2457            Number of grid points per axis used for visualization.
2458
2459        Raises
2460        ------
2461        NotImplementedError
2462            If called on a 2D operator, since this visualization is only implemented for 1D.
2463
2464        Notes
2465        -----
2466        - This method visualizes the vector field (∂p/∂ξ) in phase space.
2467        - Used for analyzing wave propagation properties and dispersion relations.
2468        - Requires symbolic expression self.expr depending on x and ξ.
2469        """
2470        if self.dim != 1:
2471            raise NotImplementedError("Only 1D group velocity visualization implemented.")
2472
2473        x, = self.vars_x
2474        xi = symbols('xi', real=True)
2475        dp_dxi = diff(self.symbol, xi)
2476        grad_func = lambdify((x, xi), dp_dxi, 'numpy')
2477
2478        x_vals = np.linspace(*xlim, density)
2479        xi_vals = np.linspace(*klim, density)
2480        X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
2481        V = grad_func(X, XI)
2482
2483        plt.quiver(X, XI, np.ones_like(V), V, scale=10, width=0.004)
2484        plt.xlabel('x')
2485        plt.ylabel(r'$\xi$')
2486        plt.title("Group Velocity Field (1D)")
2487        plt.grid(True)
2488        plt.show()
2489
2490    def animate_singularity(self, xi0=5.0, eta0=0.0, x0=0.0, y0=0.0,
2491                            tmax=4.0, n_frames=100, projection=None):
2492        """
2493        Animate the propagation of a singularity under the Hamiltonian flow.
2494
2495        This method visualizes how a singularity (x₀, y₀, ξ₀, η₀) evolves in phase space 
2496        according to the Hamiltonian dynamics induced by the principal symbol of the operator.
2497        The animation integrates the Hamiltonian equations of motion and supports various projections:
2498        position (x-y), frequency (ξ-η), or mixed phase space coordinates.
2499
2500        Parameters
2501        ----------
2502        xi0, eta0 : float
2503            Initial frequency components (ξ₀, η₀).
2504        x0, y0 : float
2505            Initial spatial coordinates (x₀, y₀).
2506        tmax : float
2507            Total time of integration (final animation time).
2508        n_frames : int
2509            Number of frames in the resulting animation.
2510        projection : str or None
2511            Type of projection to display:
2512                - 'position' : x vs y (or x alone in 1D)
2513                - 'frequency': ξ vs η (or ξ alone in 1D)
2514                - 'phase'    : mixed coordinates like x vs ξ or x vs η
2515                If None, defaults to 'phase' in 1D and 'position' in 2D.
2516
2517        Returns
2518        -------
2519        matplotlib.animation.FuncAnimation
2520            Animation object that can be displayed interactively in Jupyter notebooks or saved as a video.
2521
2522        Notes
2523        -----
2524        - In 1D, only one spatial and one frequency variable are used.
2525        - Complex-valued Hamiltonian fields are truncated to their real parts for integration.
2526        - Trajectories are shown with both instantaneous position (dot) and full path (dashed line).
2527        """
2528        rc('animation', html='jshtml')
2529    
2530        def make_real(expr):
2531            from sympy import re, simplify
2532            expr = expr.doit(deep=True)
2533            return simplify(re(expr))
2534  
2535        H = self.symplectic_flow()
2536
2537        H = {k: v.doit(deep=True) for k, v in H.items()}
2538
2539        print("H = ", H)
2540    
2541        if any(im(H[k]) != 0 for k in H):
2542            print("⚠️ The Hamiltonian field is complex. Only the real part is used for integration.")
2543    
2544        if self.dim == 1:
2545            x, = self.vars_x
2546            xi = symbols('xi', real=True)
2547    
2548            dxdt = lambdify((x, xi), make_real(H['dx/dt']), 'numpy')
2549            dxidt = lambdify((x, xi), make_real(H['dxi/dt']), 'numpy')
2550    
2551            def hamilton(t, Y):
2552                x, xi = Y
2553                return [dxdt(x, xi), dxidt(x, xi)]
2554    
2555            sol = solve_ivp(hamilton, [0, tmax], [x0, xi0],
2556                            t_eval=np.linspace(0, tmax, n_frames))
2557            
2558            if sol.status != 0:
2559                print(f"⚠️ Integration warning: {sol.message}")
2560            
2561            n_points = sol.y.shape[1]
2562            if n_points < n_frames:
2563                print(f"⚠️ Only {n_points} frames computed. Adjusting animation.")
2564                n_frames = n_points
2565
2566            x_vals, xi_vals = sol.y
2567    
2568            if projection is None:
2569                projection = 'phase'
2570    
2571            fig, ax = plt.subplots()
2572            point, = ax.plot([], [], 'ro')
2573            traj, = ax.plot([], [], 'b--', lw=1, alpha=0.5)
2574    
2575            if projection == 'phase':
2576                ax.set_xlabel('x')
2577                ax.set_ylabel(r'$\xi$')
2578                ax.set_xlim(np.min(x_vals) - 1, np.max(x_vals) + 1)
2579                ax.set_ylim(np.min(xi_vals) - 1, np.max(xi_vals) + 1)
2580    
2581                def update(i):
2582                    point.set_data([x_vals[i]], [xi_vals[i]])
2583                    traj.set_data(x_vals[:i+1], xi_vals[:i+1])
2584                    return point, traj
2585    
2586            elif projection == 'position':
2587                ax.set_xlabel('x')
2588                ax.set_ylabel('x')
2589                ax.set_xlim(np.min(x_vals) - 1, np.max(x_vals) + 1)
2590                ax.set_ylim(np.min(x_vals) - 1, np.max(x_vals) + 1)
2591    
2592                def update(i):
2593                    point.set_data([x_vals[i]], [x_vals[i]])
2594                    traj.set_data(x_vals[:i+1], x_vals[:i+1])
2595                    return point, traj
2596    
2597            elif projection == 'frequency':
2598                ax.set_xlabel(r'$\xi$')
2599                ax.set_ylabel(r'$\xi$')
2600                ax.set_xlim(np.min(xi_vals) - 1, np.max(xi_vals) + 1)
2601                ax.set_ylim(np.min(xi_vals) - 1, np.max(xi_vals) + 1)
2602    
2603                def update(i):
2604                    point.set_data([xi_vals[i]], [xi_vals[i]])
2605                    traj.set_data(xi_vals[:i+1], xi_vals[:i+1])
2606                    return point, traj
2607    
2608            else:
2609                raise ValueError("Invalid projection mode")
2610    
2611            ax.set_title(f"1D Singularity Flow ({projection})")
2612            ax.grid(True)
2613            ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=50)
2614            plt.close(fig)
2615            return ani
2616    
2617        elif self.dim == 2:
2618            x, y = self.vars_x
2619            xi, eta = symbols('xi eta', real=True)
2620    
2621            dxdt = lambdify((x, y, xi, eta), make_real(H['dx/dt']), 'numpy')
2622            dydt = lambdify((x, y, xi, eta), make_real(H['dy/dt']), 'numpy')
2623            dxidt = lambdify((x, y, xi, eta), make_real(H['dxi/dt']), 'numpy')
2624            detadt = lambdify((x, y, xi, eta), make_real(H['deta/dt']), 'numpy')
2625    
2626            def hamilton(t, Y):
2627                x, y, xi, eta = Y
2628                return [
2629                    dxdt(x, y, xi, eta),
2630                    dydt(x, y, xi, eta),
2631                    dxidt(x, y, xi, eta),
2632                    detadt(x, y, xi, eta)
2633                ]
2634    
2635            sol = solve_ivp(hamilton, [0, tmax], [x0, y0, xi0, eta0],
2636                            t_eval=np.linspace(0, tmax, n_frames))
2637
2638            if sol.status != 0:
2639                print(f"⚠️ Integration warning: {sol.message}")
2640            
2641            n_points = sol.y.shape[1]
2642            if n_points < n_frames:
2643                print(f"⚠️ Only {n_points} frames computed. Adjusting animation.")
2644                n_frames = n_points
2645                
2646            x_vals, y_vals, xi_vals, eta_vals = sol.y
2647    
2648            if projection is None:
2649                projection = 'position'
2650    
2651            fig, ax = plt.subplots()
2652            point, = ax.plot([], [], 'ro')
2653            traj, = ax.plot([], [], 'b--', lw=1, alpha=0.5)
2654    
2655            if projection == 'position':
2656                ax.set_xlabel('x')
2657                ax.set_ylabel('y')
2658                ax.set_xlim(np.min(x_vals) - 1, np.max(x_vals) + 1)
2659                ax.set_ylim(np.min(y_vals) - 1, np.max(y_vals) + 1)
2660    
2661                def update(i):
2662                    point.set_data([x_vals[i]], [y_vals[i]])
2663                    traj.set_data(x_vals[:i+1], y_vals[:i+1])
2664                    return point, traj
2665    
2666            elif projection == 'frequency':
2667                ax.set_xlabel(r'$\xi$')
2668                ax.set_ylabel(r'$\eta$')
2669                ax.set_xlim(np.min(xi_vals) - 1, np.max(xi_vals) + 1)
2670                ax.set_ylim(np.min(eta_vals) - 1, np.max(eta_vals) + 1)
2671    
2672                def update(i):
2673                    point.set_data([xi_vals[i]], [eta_vals[i]])
2674                    traj.set_data(xi_vals[:i+1], eta_vals[:i+1])
2675                    return point, traj
2676    
2677            elif projection == 'phase':
2678                ax.set_xlabel('x')
2679                ax.set_ylabel(r'$\eta$')
2680                ax.set_xlim(np.min(x_vals) - 1, np.max(x_vals) + 1)
2681                ax.set_ylim(np.min(eta_vals) - 1, np.max(eta_vals) + 1)
2682    
2683                def update(i):
2684                    point.set_data([x_vals[i]], [eta_vals[i]])
2685                    traj.set_data(x_vals[:i+1], eta_vals[:i+1])
2686                    return point, traj
2687    
2688            else:
2689                raise ValueError("Invalid projection mode")
2690    
2691            ax.set_title(f"2D Singularity Flow ({projection})")
2692            ax.grid(True)
2693            ax.axis('equal')
2694            ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=50)
2695            plt.close(fig)
2696            return ani
2697
2698    def interactive_symbol_analysis(pseudo_op,
2699                                    xlim=(-2, 2), ylim=(-2, 2),
2700                                    xi_range=(0.1, 5), eta_range=(-5, 5),
2701                                    density=100):
2702        """
2703        Launch an interactive dashboard for symbol exploration using ipywidgets.
2704    
2705        This function provides a user-friendly interface to visualize various aspects of the pseudo-differential operator's symbol.
2706        It supports multiple visualization modes in both 1D and 2D, including group velocity fields, micro-support estimates,
2707        symplectic vector fields, symbol amplitude/phase, cotangent fiber structure, characteristic sets and Hamiltonian flows.
2708    
2709        Parameters
2710        ----------
2711        pseudo_op : PseudoDifferentialOperator
2712            The pseudo-differential operator whose symbol is to be analyzed interactively.
2713        xlim, ylim : tuple of float
2714            Spatial domain limits along x and y axes respectively.
2715        xi_range, eta_range : tuple
2716            Frequency domain limits along ξ and η axes respectively.
2717        density : int
2718            Number of points per axis used to construct the evaluation grid. Controls resolution.
2719    
2720        Notes
2721        -----
2722        - In 1D mode, sliders control the fixed frequency (ξ₀) and spatial position (x₀).
2723        - In 2D mode, additional sliders control the second frequency component (η₀) and second spatial coordinate (y₀).
2724        - Visualization updates dynamically as parameters are adjusted via sliders or dropdown menus.
2725        - Supported visualization modes:
2726            'Symbol Amplitude'           : |p(x,ξ)| or |p(x,y,ξ,η)|
2727            'Symbol Phase'               : arg(p(x,ξ)) or similar in 2D
2728            'Micro-Support (1/|p|)'      : Reciprocal of symbol magnitude
2729            'Cotangent Fiber'            : Structure of symbol over frequency space at fixed x
2730            'Characteristic Set'         : Zero set approximation {p ≈ 0}
2731            'Characteristic Gradient'    : |∇p(x, ξ)| or |∇p(x₀, y₀, ξ, η)|
2732            'Group Velocity Field'       : ∇_ξ p(x,ξ) or ∇_{ξ,η} p(x,y,ξ,η)
2733            'Symplectic Vector Field'    : (∇_ξ p, -∇_x p) or similar in 2D
2734            'Hamiltonian Flow'           : Trajectories generated by the Hamiltonian vector field
2735    
2736        Raises
2737        ------
2738        NotImplementedError
2739            If the spatial dimension is not 1D or 2D.
2740    
2741        Prints
2742        ------
2743        Interactive matplotlib figures with dynamic updates based on widget inputs.
2744        """
2745        dim = pseudo_op.dim
2746        expr = pseudo_op.expr
2747        vars_x = pseudo_op.vars_x
2748    
2749        mode_selector_1D = Dropdown(
2750            options=[
2751                'Symbol Amplitude',
2752                'Symbol Phase',
2753                'Micro-Support (1/|p|)',
2754                'Cotangent Fiber',
2755                'Characteristic Set',
2756                'Characteristic Gradient',
2757                'Group Velocity Field',
2758                'Symplectic Vector Field',
2759                'Hamiltonian Flow',
2760            ],
2761            value='Symbol Amplitude',
2762            description='Mode:'
2763        )
2764
2765        mode_selector_2D = Dropdown(
2766            options=[
2767                'Symbol Amplitude',
2768                'Symbol Phase',
2769                'Micro-Support (1/|p|)',
2770                'Cotangent Fiber',
2771                'Characteristic Set',
2772                'Characteristic Gradient',
2773                'Symplectic Vector Field',
2774                'Hamiltonian Flow',
2775            ],
2776            value='Symbol Amplitude',
2777            description='Mode:'
2778        )
2779    
2780        x_vals = np.linspace(*xlim, density)
2781        if dim == 2:
2782            y_vals = np.linspace(*ylim, density)
2783    
2784        if dim == 1:
2785            x, = vars_x
2786            xi = symbols('xi', real=True)
2787            grad_func = lambdify((x, xi), diff(expr, xi), 'numpy')
2788            symplectic_func = lambdify((x, xi), [diff(expr, xi), -diff(expr, x)], 'numpy')
2789            symbol_func = lambdify((x, xi), expr, 'numpy')
2790
2791            xi_slider = FloatSlider(min=xi_range[0], max=xi_range[1], step=0.1, value=1.0, description='ξ₀')
2792            x_slider = FloatSlider(min=xlim[0], max=xlim[1], step=0.1, value=0.0, description='x₀')
2793    
2794            def plot_1d(mode, xi0, x0):
2795                X = x_vals[:, None]
2796    
2797                if mode == 'Group Velocity Field':
2798                    V = grad_func(X, xi0)
2799                    plt.quiver(X, V, np.ones_like(V), V, scale=10, width=0.004)
2800                    plt.xlabel('x')
2801                    plt.title(f'Group Velocity Field at ξ={xi0:.2f}')
2802    
2803                elif mode == 'Micro-Support (1/|p|)':
2804                    Z = 1 / (np.abs(symbol_func(X, xi0)) + 1e-10)
2805                    plt.plot(x_vals, Z)
2806                    plt.xlabel('x')
2807                    plt.title(f'Micro-Support (1/|p|) at ξ={xi0:.2f}')
2808    
2809                elif mode == 'Symplectic Vector Field':
2810                    U, V = symplectic_func(X, xi0)
2811                    plt.quiver(X, V, U, V, scale=10, width=0.004)
2812                    plt.xlabel('x')
2813                    plt.title(f'Symplectic Field at ξ={xi0:.2f}')
2814    
2815                elif mode == 'Symbol Amplitude':
2816                    Z = np.abs(symbol_func(X, xi0))
2817                    plt.plot(x_vals, Z)
2818                    plt.xlabel('x')
2819                    plt.title(f'Symbol Amplitude |p(x,ξ)| at ξ={xi0:.2f}')
2820    
2821                elif mode == 'Symbol Phase':
2822                    Z = np.angle(symbol_func(X, xi0))
2823                    plt.plot(x_vals, Z)
2824                    plt.xlabel('x')
2825                    plt.title(f'Symbol Phase arg(p(x,ξ)) at ξ={xi0:.2f}')
2826    
2827                elif mode == 'Cotangent Fiber':
2828                    pseudo_op.visualize_fiber(x_vals, np.linspace(*xi_range, density), x0=x0)
2829    
2830                elif mode == 'Characteristic Set':
2831                    pseudo_op.visualize_characteristic_set(x_vals, np.linspace(*xi_range, density), x0=x0)
2832    
2833                elif mode == 'Characteristic Gradient':
2834                    pseudo_op.visualize_characteristic_gradient(x_vals, np.linspace(*xi_range, density), x0=x0)
2835    
2836                elif mode == 'Hamiltonian Flow':
2837                    pseudo_op.plot_hamiltonian_flow(x0=x0, xi0=xi0)
2838    
2839            # --- Dynamic container for sliders ---
2840            controls_box = VBox([mode_selector_1D, xi_slider, x_slider])
2841            # --- Function to adjust visible sliders based on mode ---
2842            def update_controls(change):
2843                mode = change['new']
2844                # modes that depend only on xi and eta
2845                if mode in ['Symbol Amplitude', 'Symbol Phase', 'Micro-Support (1/|p|)',
2846                            'Group Velocity Field', 'Symplectic Vector Field']:
2847                    controls_box.children = [mode_selector_1D, xi_slider]
2848                # modes that require xi and x
2849                elif mode in ['Hamiltonian Flow']:
2850                    controls_box.children = [mode_selector_1D, xi_slider, x_slider]
2851                # modes that require nothing
2852                elif mode in ['Cotangent Fiber', 'Characteristic Set', 'Characteristic Gradient']:
2853                    controls_box.children = [mode_selector_1D]
2854            mode_selector_1D.observe(update_controls, names='value')
2855            update_controls({'new': mode_selector_1D.value}) 
2856            # --- Interactive binding ---
2857            out = interactive_output(plot_1d, {'mode': mode_selector_1D, 'xi0': xi_slider, 'x0': x_slider})
2858            display(VBox([controls_box, out]))
2859
2860        elif dim == 2:
2861            x, y = vars_x
2862            xi, eta = symbols('xi eta', real=True)
2863            symplectic_func = lambdify((x, y, xi, eta), [diff(expr, xi), diff(expr, eta)], 'numpy')
2864            symbol_func = lambdify((x, y, xi, eta), expr, 'numpy')
2865
2866            xi_slider=FloatSlider(min=xi_range[0], max=xi_range[1], step=0.1, value=1.0, description='ξ₀')
2867            eta_slider=FloatSlider(min=eta_range[0], max=eta_range[1], step=0.1, value=1.0, description='η₀')
2868            x_slider=FloatSlider(min=xlim[0], max=xlim[1], step=0.1, value=0.0, description='x₀')
2869            y_slider=FloatSlider(min=ylim[0], max=ylim[1], step=0.1, value=0.0, description='y₀')
2870    
2871            def plot_2d(mode, xi0, eta0, x0, y0):
2872                X, Y = np.meshgrid(x_vals, y_vals, indexing='ij')
2873    
2874                if mode == 'Micro-Support (1/|p|)':
2875                    Z = 1 / (np.abs(symbol_func(X, Y, xi0, eta0)) + 1e-10)
2876                    plt.pcolormesh(X, Y, Z, shading='auto', cmap='inferno')
2877                    plt.colorbar(label='1/|p|')
2878                    plt.xlabel('x')
2879                    plt.ylabel('y')
2880                    plt.title(f'Micro-Support at ξ={xi0:.2f}, η={eta0:.2f}')
2881    
2882                elif mode == 'Symplectic Vector Field':
2883                    U, V = symplectic_func(X, Y, xi0, eta0)
2884                    plt.quiver(X, Y, U, V, scale=10, width=0.004)
2885                    plt.xlabel('x')
2886                    plt.ylabel('y')
2887                    plt.title(f'Symplectic Field at ξ={xi0:.2f}, η={eta0:.2f}')
2888    
2889                elif mode == 'Symbol Amplitude':
2890                    Z = np.abs(symbol_func(X, Y, xi0, eta0))
2891                    plt.pcolormesh(X, Y, Z, shading='auto')
2892                    plt.colorbar(label='|p(x,y,ξ,η)|')
2893                    plt.xlabel('x')
2894                    plt.ylabel('y')
2895                    plt.title(f'Symbol Amplitude at ξ={xi0:.2f}, η={eta0:.2f}')
2896    
2897                elif mode == 'Symbol Phase':
2898                    Z = np.angle(symbol_func(X, Y, xi0, eta0))
2899                    plt.pcolormesh(X, Y, Z, shading='auto', cmap='twilight')
2900                    plt.colorbar(label='arg(p)')
2901                    plt.xlabel('x')
2902                    plt.ylabel('y')
2903                    plt.title(f'Symbol Phase at ξ={xi0:.2f}, η={eta0:.2f}')
2904    
2905                elif mode == 'Cotangent Fiber':
2906                    pseudo_op.visualize_fiber(np.linspace(*xi_range, density), np.linspace(*eta_range, density),
2907                                              x0=x0, y0=y0)
2908    
2909                elif mode == 'Characteristic Set':
2910                    pseudo_op.visualize_characteristic_set(x_grid=x_vals, xi_grid=np.linspace(*xi_range, density),
2911                                                  y_grid=y_vals, eta_grid=np.linspace(*eta_range, density), x0=x0, y0=y0)
2912    
2913                elif mode == 'Characteristic Gradient':
2914                    pseudo_op.visualize_characteristic_gradient(x_grid=x_vals, xi_grid=np.linspace(*xi_range, density),
2915                                                  y_grid=y_vals, eta_grid=np.linspace(*eta_range, density), x0=x0, y0=y0)
2916    
2917                elif mode == 'Hamiltonian Flow':
2918                    pseudo_op.plot_hamiltonian_flow(x0=x0, y0=y0, xi0=xi0, eta0=eta0)
2919                    
2920            # --- Dynamic container for sliders ---
2921            controls_box = VBox([mode_selector_2D, xi_slider, eta_slider, x_slider, y_slider])
2922            # --- Function to adjust visible sliders based on mode ---
2923            def update_controls(change):
2924                mode = change['new']
2925                # modes that depend only on xi
2926                if mode in ['Symbol Amplitude', 'Symbol Phase', 'Micro-Support (1/|p|)', 'Symplectic Vector Field']:
2927                    controls_box.children = [mode_selector_2D, xi_slider, eta_slider]
2928                # modes that require xi, eta, x and y
2929                elif mode in ['Hamiltonian Flow']:
2930                    controls_box.children = [mode_selector_2D, xi_slider, eta_slider, x_slider, y_slider]
2931                # modes that require x and y
2932                elif mode in ['Cotangent Fiber', 'Characteristic Set', 'Characteristic Gradient']:
2933                    controls_box.children = [mode_selector_2D, x_slider, y_slider]
2934            mode_selector_2D.observe(update_controls, names='value')
2935            update_controls({'new': mode_selector_2D.value}) 
2936            # --- Interactive binding ---
2937            out = interactive_output(plot_2d, {'mode': mode_selector_2D, 'xi0': xi_slider, 'eta0': eta_slider, 'x0': x_slider, 'y0': y_slider})
2938            display(VBox([controls_box, out]))

Pseudo-differential operator with dynamic symbol evaluation on spatial grids. Supports both 1D and 2D operators, and can be defined explicitly (symbol mode) or extracted automatically from symbolic equations (auto mode).

Parameters

expr : sympy expression Symbolic expression representing the pseudo-differential symbol. vars_x : list of sympy symbols Spatial variables (e.g., [x] for 1D, [x, y] for 2D). var_u : sympy function, optional Function u(x, t) used in auto mode to extract the operator symbol. mode : str, {'symbol', 'auto'} - 'symbol': directly uses expr as the operator symbol. - 'auto': computes the symbol automatically by applying expr to exp(i x ξ).

Attributes

dim : int Spatial dimension (1 or 2). fft, ifft : callable Fast Fourier transform and inverse (scipy.fft or scipy.fft2). p_func : callable Evaluated symbol function ready for numerical use.

Notes

  • In 'symbol' mode, expr should be expressed in terms of spatial variables and frequency variables (ξ, η).
  • In 'auto' mode, the symbol is derived by applying the differential expression to a complex exponential.
  • Frequency variables are internally named 'xi' and 'eta' for consistency.
  • Uses numpy for numerical evaluation and scipy.fft for FFT operations.

Examples

>>> # Example 1: 1D Laplacian operator (symbol mode)
>>> from sympy import symbols
>>> x, xi = symbols('x xi', real=True)
>>> op = PseudoDifferentialOperator(expr=xi**2, vars_x=[x], mode='symbol')
>>> # Example 2: 1D transport operator (auto mode)
>>> from sympy import Function
>>> u = Function('u')
>>> expr = u(x).diff(x)
>>> op = PseudoDifferentialOperator(expr=expr, vars_x=[x], var_u=u(x), mode='auto')
PseudoDifferentialOperator(expr, vars_x, var_u=None, mode='symbol')
 76    def __init__(self, expr, vars_x, var_u=None, mode='symbol'):
 77        self.dim = len(vars_x)
 78        self.mode = mode
 79        self.symbol_cached = None
 80        self.expr = expr
 81        self.vars_x = vars_x
 82
 83        if self.dim == 1:
 84            x, = vars_x
 85            xi_internal = symbols('xi', real=True)
 86            expr = expr.subs(symbols('xi', real=True), xi_internal)
 87            self.fft = partial(fft, workers=FFT_WORKERS)
 88            self.ifft = partial(ifft, workers=FFT_WORKERS)
 89
 90            if mode == 'symbol':
 91                self.p_func = lambdify((x, xi_internal), expr, 'numpy')
 92                self.symbol = expr
 93            elif mode == 'auto':
 94                if var_u is None:
 95                    raise ValueError("var_u must be provided in mode='auto'")
 96                exp_i = exp(I * x * xi_internal)
 97                P_ei = expr.subs(var_u, exp_i)
 98                symbol = simplify(P_ei / exp_i)
 99                symbol = expand(symbol)
100                self.symbol = symbol
101                self.p_func = lambdify((x, xi_internal), symbol, 'numpy')
102            else:
103                raise ValueError("mode must be 'auto' or 'symbol'")
104
105        elif self.dim == 2:
106            x, y = vars_x
107            xi_internal, eta_internal = symbols('xi eta', real=True)
108            expr = expr.subs(symbols('xi', real=True), xi_internal)
109            expr = expr.subs(symbols('eta', real=True), eta_internal)
110            self.fft = partial(fft2, workers=FFT_WORKERS)
111            self.ifft = partial(ifft2, workers=FFT_WORKERS)
112
113            if mode == 'symbol':
114                self.symbol = expr
115                self.p_func = lambdify((x, y, xi_internal, eta_internal), expr, 'numpy')
116            elif mode == 'auto':
117                if var_u is None:
118                    raise ValueError("var_u must be provided in mode='auto'")
119                exp_i = exp(I * (x * xi_internal + y * eta_internal))
120                P_ei = expr.subs(var_u, exp_i)
121                symbol = simplify(P_ei / exp_i)
122                symbol = expand(symbol)
123                self.symbol = symbol
124                self.p_func = lambdify((x, y, xi_internal, eta_internal), symbol, 'numpy')
125            else:
126                raise ValueError("mode must be 'auto' or 'symbol'")
127
128        else:
129            raise NotImplementedError("Only 1D and 2D supported")
130
131        if mode == 'auto':
132            print("\nsymbol = ")
133            pprint(self.symbol, num_columns=NUM_COLS)
dim
mode
symbol_cached
expr
vars_x
def evaluate(self, X, Y, KX, KY, cache=True):
135    def evaluate(self, X, Y, KX, KY, cache=True):
136        """
137        Evaluate the pseudo-differential operator's symbol on a grid of spatial and frequency coordinates.
138
139        The method dynamically selects between 1D and 2D evaluation based on the spatial dimension.
140        If caching is enabled and a cached symbol exists, it returns the cached result to avoid recomputation.
141
142        Parameters
143        ----------
144        X, Y : ndarray
145            Spatial grid coordinates. In 1D, Y is ignored.
146        KX, KY : ndarray
147            Frequency grid coordinates. In 1D, KY is ignored.
148        cache : bool, default=True
149            If True, stores the computed symbol for reuse in subsequent calls to avoid redundant computation.
150
151        Returns
152        -------
153        ndarray
154            Evaluated symbol values over the input grid. Shape matches the input spatial/frequency grids.
155
156        Raises
157        ------
158        NotImplementedError
159            If the spatial dimension is not 1D or 2D.
160        """
161        if cache and self.symbol_cached is not None:
162            return self.symbol_cached
163
164        if self.dim == 1:
165            symbol = self.p_func(X, KX)
166        elif self.dim == 2:
167            symbol = self.p_func(X, Y, KX, KY)
168
169        if cache:
170            self.symbol_cached = symbol
171
172        return symbol

Evaluate the pseudo-differential operator's symbol on a grid of spatial and frequency coordinates.

The method dynamically selects between 1D and 2D evaluation based on the spatial dimension. If caching is enabled and a cached symbol exists, it returns the cached result to avoid recomputation.

Parameters

X, Y : ndarray Spatial grid coordinates. In 1D, Y is ignored. KX, KY : ndarray Frequency grid coordinates. In 1D, KY is ignored. cache : bool, default=True If True, stores the computed symbol for reuse in subsequent calls to avoid redundant computation.

Returns

ndarray Evaluated symbol values over the input grid. Shape matches the input spatial/frequency grids.

Raises

NotImplementedError If the spatial dimension is not 1D or 2D.

def clear_cache(self):
174    def clear_cache(self):
175        """
176        Clear cached symbol evaluations.
177        """        
178        self.symbol_cached = None

Clear cached symbol evaluations.

def apply( self, u, x_grid, kx, boundary_condition='periodic', y_grid=None, ky=None, dealiasing_mask=None, freq_window='gaussian', clamp=1000000.0, space_window=False):
180    def apply(self, u, x_grid, kx, boundary_condition='periodic', 
181              y_grid=None, ky=None, dealiasing_mask=None,
182              freq_window='gaussian', clamp=1e6, space_window=False):
183        """
184        Apply the pseudo-differential operator to the input field u.
185    
186        This method dispatches the application of the pseudo-differential operator based on:
187        
188        - Whether the symbol is spatially dependent (x/y)
189        - The boundary condition in use (periodic or dirichlet)
190    
191        Supported operations:
192        
193        - Constant-coefficient symbols: applied via Fourier multiplication.
194        - Spatially varying symbols: applied via Kohn–Nirenberg quantization.
195        - Dirichlet boundary conditions: handled with non-periodic convolution-like quantization.
196    
197        Dispatch Logic:\n
198        if not self.is_spatial: u ↦ Op(p)(D) ⋅ u = 𝓕⁻¹[ p(ξ) ⋅ 𝓕(u) ]\n
199        elif periodic: u ↦ Op(p)(x,D) ⋅ u ≈ ∫ eᶦˣᶿ p(x, ξ) 𝓕(u)(ξ) dξ based of FFT (quicker)\n
200        elif dirichlet: u ↦ Op(p)(x,D) ⋅ u ≈ u ≈ ∫ eᶦˣᶿ p(x, ξ) 𝓕(u)(ξ) dξ (slower)\n
201        
202        Parameters
203        ----------
204        u : ndarray
205            Function to which the operator is applied
206        x_grid : ndarray
207            Spatial grid in x direction
208        kx : ndarray
209            Frequency grid in x direction
210        boundary_condition : str
211            'periodic' or 'dirichlet'
212        y_grid : ndarray, optional
213            Spatial grid in y direction (for 2D)
214        ky : ndarray, optional
215            Frequency grid in y direction (for 2D)
216        dealiasing_mask : ndarray, optional
217            Dealiasing mask
218        freq_window : str
219            Frequency windowing ('gaussian' or 'hann')
220        clamp : float
221            Clamp symbol values to [-clamp, clamp]
222        space_window : bool
223            Apply spatial windowing
224            
225        Returns
226        -------
227        ndarray
228            Result of applying the operator
229        """
230        # Check if symbol depends on spatial variables
231        is_spatial = self._is_spatial_dependent()
232        
233        # Case 1: Constant symbol with periodic BC (fast path)
234        if not is_spatial and boundary_condition == 'periodic':
235            return self._apply_constant_fft(u, x_grid, kx, y_grid, ky, dealiasing_mask)
236        
237        # Case 2: Spatial symbol with periodic BC
238        elif boundary_condition == 'periodic':
239            symbol_func = self._get_symbol_func()
240            return kohn_nirenberg_fft(
241                u_vals=u,
242                symbol_func=symbol_func,
243                x_grid=x_grid,
244                kx=kx,
245                fft_func=self.fft,
246                ifft_func=self.ifft,
247                dim=self.dim,
248                y_grid=y_grid,
249                ky=ky,
250                freq_window=freq_window,
251                clamp=clamp,
252                space_window=space_window
253            )
254        
255        # Case 3: Dirichlet BC (non-periodic)
256        elif boundary_condition == 'dirichlet':
257            symbol_func = self._get_symbol_func()
258            
259            if self.dim == 1:
260                return kohn_nirenberg_nonperiodic(
261                    u_vals=u,
262                    x_grid=x_grid,
263                    xi_grid=kx,
264                    symbol_func=symbol_func,
265                    freq_window=freq_window,
266                    clamp=clamp,
267                    space_window=space_window
268                )
269            elif self.dim == 2:
270                return kohn_nirenberg_nonperiodic(
271                    u_vals=u,
272                    x_grid=(x_grid, y_grid),
273                    xi_grid=(kx, ky),
274                    symbol_func=symbol_func,
275                    freq_window=freq_window,
276                    clamp=clamp,
277                    space_window=space_window
278                )
279        
280        else:
281            raise ValueError(f"Invalid boundary condition '{boundary_condition}'")

Apply the pseudo-differential operator to the input field u.

This method dispatches the application of the pseudo-differential operator based on:

  • Whether the symbol is spatially dependent (x/y)
  • The boundary condition in use (periodic or dirichlet)

Supported operations:

  • Constant-coefficient symbols: applied via Fourier multiplication.
  • Spatially varying symbols: applied via Kohn–Nirenberg quantization.
  • Dirichlet boundary conditions: handled with non-periodic convolution-like quantization.

Dispatch Logic:

if not self.is_spatial: u ↦ Op(p)(D) ⋅ u = 𝓕⁻¹[ p(ξ) ⋅ 𝓕(u) ]

elif periodic: u ↦ Op(p)(x,D) ⋅ u ≈ ∫ eᶦˣᶿ p(x, ξ) 𝓕(u)(ξ) dξ based of FFT (quicker)

elif dirichlet: u ↦ Op(p)(x,D) ⋅ u ≈ u ≈ ∫ eᶦˣᶿ p(x, ξ) 𝓕(u)(ξ) dξ (slower)

Parameters

u : ndarray Function to which the operator is applied x_grid : ndarray Spatial grid in x direction kx : ndarray Frequency grid in x direction boundary_condition : str 'periodic' or 'dirichlet' y_grid : ndarray, optional Spatial grid in y direction (for 2D) ky : ndarray, optional Frequency grid in y direction (for 2D) dealiasing_mask : ndarray, optional Dealiasing mask freq_window : str Frequency windowing ('gaussian' or 'hann') clamp : float Clamp symbol values to [-clamp, clamp] space_window : bool Apply spatial windowing

Returns

ndarray Result of applying the operator

def principal_symbol(self, order=1):
379    def principal_symbol(self, order=1):
380        """
381        Compute the leading homogeneous component of the pseudo-differential symbol.
382
383        This method extracts the principal part of the symbol, which is the dominant 
384        term under high-frequency asymptotics (|ξ| → ∞). The expansion is performed 
385        in polar coordinates for 2D symbols to maintain rotational symmetry, then 
386        converted back to Cartesian form.
387
388        Parameters
389        ----------
390        order : int
391            Order of the asymptotic expansion in powers of 1/ρ, where ρ = |ξ| in 1D 
392            or ρ = sqrt(ξ² + η²) in 2D. Only the leading-order term is returned.
393
394        Returns
395        -------
396        sympy.Expr
397            The principal symbol component, homogeneous of degree `m - order`, where 
398            `m` is the original symbol's order.
399
400        Notes:
401        - In 1D, uses direct series expansion in ξ.
402        - In 2D, expands in radial variable ρ while preserving angular dependence.
403        - Useful for microlocal analysis and constructing parametrices.
404        """
405
406        p = self.symbol
407        if self.dim == 1:
408            xi = symbols('xi', real=True, positive=True)
409            return simplify(series(p, xi, oo, n=order).removeO())
410        elif self.dim == 2:
411            xi, eta = symbols('xi eta', real=True, positive=True)
412            # Homogeneous radial expansion: we set (ξ, η) = ρ (cosθ, sinθ)
413            rho, theta = symbols('rho theta', real=True, positive=True)
414            p_rho = p.subs({xi: rho * cos(theta), eta: rho * sin(theta)})
415            expansion = series(p_rho, rho, oo, n=order).removeO()
416            # Revert back to (ξ, η)
417            expansion_cart = expansion.subs({rho: sqrt(xi**2 + eta**2),
418                                             cos(theta): xi / sqrt(xi**2 + eta**2),
419                                             sin(theta): eta / sqrt(xi**2 + eta**2)})
420            return simplify(powdenest(expansion_cart, force=True))

Compute the leading homogeneous component of the pseudo-differential symbol.

This method extracts the principal part of the symbol, which is the dominant term under high-frequency asymptotics (|ξ| → ∞). The expansion is performed in polar coordinates for 2D symbols to maintain rotational symmetry, then converted back to Cartesian form.

Parameters

order : int Order of the asymptotic expansion in powers of 1/ρ, where ρ = |ξ| in 1D or ρ = sqrt(ξ² + η²) in 2D. Only the leading-order term is returned.

Returns

sympy.Expr The principal symbol component, homogeneous of degree m - order, where m is the original symbol's order.

Notes:

  • In 1D, uses direct series expansion in ξ.
  • In 2D, expands in radial variable ρ while preserving angular dependence.
  • Useful for microlocal analysis and constructing parametrices.
def is_homogeneous(self, tol=1e-10):
422    def is_homogeneous(self, tol=1e-10):
423        """
424        Check whether the symbol is homogeneous in the frequency variables.
425    
426        Returns
427        -------
428        (bool, Rational or float or None)
429            Tuple (is_homogeneous, degree) where:
430            - is_homogeneous: True if the symbol satisfies p(λξ, λη) = λ^m * p(ξ, η)
431            - degree: the detected degree m if homogeneous, or None
432        """
433        from sympy import symbols, simplify, expand, Eq
434        from sympy.abc import l
435    
436        if self.dim == 1:
437            xi = symbols('xi', real=True, positive=True)
438            l = symbols('l', real=True, positive=True)
439            p = self.symbol
440            p_scaled = p.subs(xi, l * xi)
441            ratio = simplify(p_scaled / p)
442            if ratio.has(xi):
443                return False, None
444            try:
445                deg = simplify(ratio).as_base_exp()[1]
446                return True, deg
447            except Exception:
448                return False, None
449    
450        elif self.dim == 2:
451            xi, eta = symbols('xi eta', real=True, positive=True)
452            l = symbols('l', real=True, positive=True)
453            p = self.symbol
454            p_scaled = p.subs({xi: l * xi, eta: l * eta})
455            ratio = simplify(p_scaled / p)
456            # If ratio == l**m with no (xi, eta) left, it's homogeneous
457            if ratio.has(xi, eta):
458                return False, None
459            try:
460                base, exp = ratio.as_base_exp()
461                if base == l:
462                    return True, exp
463            except Exception:
464                pass
465            return False, None

Check whether the symbol is homogeneous in the frequency variables.

Returns

(bool, Rational or float or None) Tuple (is_homogeneous, degree) where: - is_homogeneous: True if the symbol satisfies p(λξ, λη) = λ^m * p(ξ, η) - degree: the detected degree m if homogeneous, or None

def symbol_order(self, max_order=10, tol=0.001):
467    def symbol_order(self, max_order=10, tol=1e-3):
468        """
469        Estimate the homogeneity order of the pseudo-differential symbol in high-frequency asymptotics.
470    
471        This method attempts to determine the leading-order behavior of the symbol p(x, ξ) or p(x, y, ξ, η)
472        as |ξ| → ∞ (in 1D) or |(ξ, η)| → ∞ (in 2D). The returned value represents the asymptotic growth or decay rate,
473        which is essential for understanding the regularity and mapping properties of the corresponding operator.
474    
475        The function uses symbolic preprocessing to ensure proper factorization of frequency variables,
476        especially in sqrt and power expressions, to avoid erroneous order detection (e.g., due to hidden scaling).
477    
478        Parameters
479        ----------
480        max_order : int, optional
481            Maximum number of terms to consider in the series expansion. Default is 10.
482        tol : float, optional
483            Tolerance threshold for evaluating the coefficient magnitude. If the coefficient is too small,
484            the detected order may be discarded. Default is 1e-3.
485    
486        Returns
487        -------
488        float or None
489            - If the symbol is homogeneous, returns its exact homogeneity degree as a float.
490            - Otherwise, estimates the dominant asymptotic order from leading terms in the expansion.
491            - Returns None if no valid order could be determined.
492    
493        Notes
494        -----
495        - In 1D:
496            Two strategies are used:
497                1. Expand directly in xi at infinity.
498                2. Substitute xi = 1/z and expand around z = 0.
499    
500        - In 2D:
501            - Transform the symbol into polar coordinates: (xi, eta) = rho*(cos(theta), sin(theta)).
502            - Expand in rho at infinity, then extract the leading term's power.
503            - An alternative substitution using 1/z is also tried if the first method fails.
504    
505        - Preprocessing steps:
506            - Sqrt expressions involving frequencies are rewritten to isolate the leading variable.
507            - Power expressions are factored explicitly to ensure correct symbolic scaling.
508    
509        - If the symbol is not homogeneous, a warning is issued, and the result should be interpreted with care.
510        
511        - For non-homogeneous symbols, only the principal asymptotic term is considered.
512    
513        Raises
514        ------
515        NotImplementedError
516            If the spatial dimension is neither 1 nor 2.
517        """
518        from sympy import (
519            symbols, series, simplify, sqrt, cos, sin, oo, powdenest, radsimp,
520            expand, expand_power_base
521        )
522    
523        def preprocess_sqrt(expr, freq):
524            return expr.replace(
525                lambda e: e.func == sqrt and freq in e.free_symbols,
526                lambda e: freq * sqrt(1 + (e.args[0] - freq**2) / freq**2)
527            )
528    
529        def preprocess_power(expr, freq):
530            return expr.replace(
531                lambda e: e.is_Pow and freq in e.free_symbols,
532                lambda e: freq**e.exp * (1 + e.base / freq**e.base.as_powers_dict().get(freq, 0))**e.exp
533            )
534    
535        def validate_order(power, coeff, vars_x, tol):
536            if power is None:
537                return None
538            if any(v in coeff.free_symbols for v in vars_x):
539                print("⚠️ Coefficient depends on spatial variables; ignoring")
540                return None
541            try:
542                coeff_val = abs(float(coeff.evalf()))
543                if coeff_val < tol:
544                    print(f"⚠️ Coefficient too small ({coeff_val:.2e} < {tol})")
545                    return None
546            except Exception as e:
547                print(f"⚠️ Coefficient evaluation failed: {e}")
548                return None
549            return int(power) if power == int(power) else float(power)
550    
551        # Homogeneity check
552        is_homog, degree = self.is_homogeneous()
553        if is_homog:
554            return float(degree)
555        else:
556            print("⚠️ The symbol is not homogeneous. The asymptotic order is not well defined.")
557    
558        if self.dim == 1:
559            x = self.vars_x[0]
560            xi = symbols('xi', real=True, positive=True)
561    
562            try:
563                print("1D symbol_order - method 1")
564                expr = preprocess_sqrt(self.symbol, xi)
565                s = series(expr, xi, oo, n=max_order).removeO()
566                lead = simplify(powdenest(s.as_leading_term(xi), force=True))
567                power = lead.as_powers_dict().get(xi, None)
568                coeff = lead / xi**power if power is not None else 0
569                print("lead =", lead)
570                print("power =", power)
571                print("coeff =", coeff)
572                order = validate_order(power, coeff, [x], tol)
573                if order is not None:
574                    return order
575            except Exception:
576                pass
577    
578            try:
579                print("1D symbol_order - method 2")
580                z = symbols('z', real=True, positive=True)
581                expr_z = preprocess_sqrt(self.symbol.subs(xi, 1/z), 1/z)
582                s = series(expr_z, z, 0, n=max_order).removeO()
583                lead = simplify(powdenest(s.as_leading_term(z), force=True))
584                power = lead.as_powers_dict().get(z, None)
585                coeff = lead / z**power if power is not None else 0
586                print("lead =", lead)
587                print("power =", power)
588                print("coeff =", coeff)
589                order = validate_order(power, coeff, [x], tol)
590                if order is not None:
591                    return -order
592            except Exception as e:
593                print(f"⚠️ fallback z failed: {e}")
594            return None
595    
596        elif self.dim == 2:
597            x, y = self.vars_x
598            xi, eta = symbols('xi eta', real=True, positive=True)
599            rho, theta = symbols('rho theta', real=True, positive=True)
600    
601            try:
602                print("2D symbol_order - method 1")
603                p_rho = self.symbol.subs({xi: rho * cos(theta), eta: rho * sin(theta)})
604                p_rho = preprocess_power(preprocess_sqrt(p_rho, rho), rho)
605                s = series(simplify(p_rho), rho, oo, n=max_order).removeO()
606                lead = radsimp(simplify(powdenest(s.as_leading_term(rho), force=True)))
607                power = lead.as_powers_dict().get(rho, None)
608                coeff = lead / rho**power if power is not None else 0
609                print("lead =", lead)
610                print("power =", power)
611                print("coeff =", coeff)
612                order = validate_order(power, coeff, [x, y], tol)
613                if order is not None:
614                    return order
615            except Exception as e:
616                print(f"⚠️ polar expansion failed: {e}")
617    
618            try:
619                print("2D symbol_order - method 2")
620                z = symbols('z', real=True, positive=True)
621                xi_eta = {xi: (1/z) * cos(theta), eta: (1/z) * sin(theta)}
622                p_rho = preprocess_sqrt(self.symbol.subs(xi_eta), 1/z)
623                s = series(simplify(p_rho), z, 0, n=max_order).removeO()
624                lead = radsimp(simplify(powdenest(s.as_leading_term(z), force=True)))
625                power = lead.as_powers_dict().get(z, None)
626                coeff = lead / z**power if power is not None else 0
627                print("lead =", lead)
628                print("power =", power)
629                print("coeff =", coeff)
630                order = validate_order(power, coeff, [x, y], tol)
631                if order is not None:
632                    return -order
633            except Exception as e:
634                print(f"⚠️ fallback z (2D) failed: {e}")
635            return None
636    
637        else:
638            raise NotImplementedError("Only 1D and 2D supported.")

Estimate the homogeneity order of the pseudo-differential symbol in high-frequency asymptotics.

This method attempts to determine the leading-order behavior of the symbol p(x, ξ) or p(x, y, ξ, η) as |ξ| → ∞ (in 1D) or |(ξ, η)| → ∞ (in 2D). The returned value represents the asymptotic growth or decay rate, which is essential for understanding the regularity and mapping properties of the corresponding operator.

The function uses symbolic preprocessing to ensure proper factorization of frequency variables, especially in sqrt and power expressions, to avoid erroneous order detection (e.g., due to hidden scaling).

Parameters

max_order : int, optional Maximum number of terms to consider in the series expansion. Default is 10. tol : float, optional Tolerance threshold for evaluating the coefficient magnitude. If the coefficient is too small, the detected order may be discarded. Default is 1e-3.

Returns

float or None - If the symbol is homogeneous, returns its exact homogeneity degree as a float. - Otherwise, estimates the dominant asymptotic order from leading terms in the expansion. - Returns None if no valid order could be determined.

Notes

  • In 1D: Two strategies are used: 1. Expand directly in xi at infinity. 2. Substitute xi = 1/z and expand around z = 0.

  • In 2D:

    • Transform the symbol into polar coordinates: (xi, eta) = rho*(cos(theta), sin(theta)).
    • Expand in rho at infinity, then extract the leading term's power.
    • An alternative substitution using 1/z is also tried if the first method fails.
  • Preprocessing steps:

    • Sqrt expressions involving frequencies are rewritten to isolate the leading variable.
    • Power expressions are factored explicitly to ensure correct symbolic scaling.
  • If the symbol is not homogeneous, a warning is issued, and the result should be interpreted with care.

  • For non-homogeneous symbols, only the principal asymptotic term is considered.

Raises

NotImplementedError If the spatial dimension is neither 1 nor 2.

def asymptotic_expansion(self, order=3):
641    def asymptotic_expansion(self, order=3):
642        """
643        Compute the asymptotic expansion of the symbol as |ξ| → ∞ (high-frequency regime).
644    
645        This method expands the pseudo-differential symbol in inverse powers of the 
646        frequency variable(s), either in 1D or 2D. It handles both polynomial and 
647        exponential symbols by performing a series expansion in 1/|ξ| up to the specified order.
648    
649        The expansion is performed directly in Cartesian coordinates for 1D symbols.
650        For 2D symbols, the method uses polar coordinates (ρ, θ) to perform the expansion 
651        at infinity in ρ, then converts the result back to Cartesian coordinates.
652    
653        Parameters
654        ----------
655        order : int, optional
656            Maximum order of the asymptotic expansion. Default is 3.
657    
658        Returns
659        -------
660        sympy.Expr
661            The asymptotic expansion of the symbol up to the given order, expressed in Cartesian coordinates.
662            If expansion fails, returns the original unexpanded symbol.
663    
664        Notes:
665        - In 1D: expansion is performed directly in terms of ξ.
666        - In 2D: the symbol is first rewritten in polar coordinates (ρ,θ), expanded asymptotically 
667          in ρ → ∞, then converted back to Cartesian coordinates (ξ,η).
668        - Handles special case when the symbol is an exponential function by expanding its argument.
669        - Symbolic normalization is applied early (via `simplify`) for 2D expressions to improve convergence.
670        - Robust to failures: catches exceptions and issues warnings instead of raising errors.
671        - Final expression is simplified using `powdenest` and `expand` for improved readability.
672        """
673        p = self.symbol
674    
675        if self.dim == 1:
676            xi = symbols('xi', real=True, positive=True)
677    
678            try:
679                # Case: exponential function
680                if p.func == exp and len(p.args) == 1:
681                    arg = p.args[0]
682                    arg_series = series(arg, xi, oo, n=order).removeO()
683                    expanded = series(exp(expand(arg_series)), xi, oo, n=order).removeO()
684                    return simplify(powdenest(expanded, force=True))
685                else:
686                    expanded = series(p, xi, oo, n=order).removeO()
687                    return simplify(powdenest(expanded, force=True))
688    
689            except Exception as e:
690                print(f"Warning: 1D expansion failed: {e}")
691                return p
692    
693        elif self.dim == 2:
694            xi, eta = symbols('xi eta', real=True, positive=True)
695            rho, theta = symbols('rho theta', real=True, positive=True)
696    
697            # Normalize before substitution
698            p = simplify(p)
699    
700            # Substitute polar coordinates
701            p_polar = p.subs({
702                xi: rho * cos(theta),
703                eta: rho * sin(theta)
704            })
705    
706            try:
707                # Handle exponentials
708                if p_polar.func == exp and len(p_polar.args) == 1:
709                    arg = p_polar.args[0]
710                    arg_series = series(arg, rho, oo, n=order).removeO()
711                    expanded = series(exp(expand(arg_series)), rho, oo, n=order).removeO()
712                else:
713                    expanded = series(p_polar, rho, oo, n=order).removeO()
714    
715                # Convert back to Cartesian
716                norm = sqrt(xi**2 + eta**2)
717                expansion_cart = expanded.subs({
718                    rho: norm,
719                    cos(theta): xi / norm,
720                    sin(theta): eta / norm
721                })
722    
723                # Final simplifications
724                result = simplify(powdenest(expansion_cart, force=True))
725                result = expand(result)
726                return result
727    
728            except Exception as e:
729                print(f"Warning: 2D expansion failed: {e}")
730                return p  

Compute the asymptotic expansion of the symbol as |ξ| → ∞ (high-frequency regime).

This method expands the pseudo-differential symbol in inverse powers of the frequency variable(s), either in 1D or 2D. It handles both polynomial and exponential symbols by performing a series expansion in 1/|ξ| up to the specified order.

The expansion is performed directly in Cartesian coordinates for 1D symbols. For 2D symbols, the method uses polar coordinates (ρ, θ) to perform the expansion at infinity in ρ, then converts the result back to Cartesian coordinates.

Parameters

order : int, optional Maximum order of the asymptotic expansion. Default is 3.

Returns

sympy.Expr The asymptotic expansion of the symbol up to the given order, expressed in Cartesian coordinates. If expansion fails, returns the original unexpanded symbol.

Notes:

  • In 1D: expansion is performed directly in terms of ξ.
  • In 2D: the symbol is first rewritten in polar coordinates (ρ,θ), expanded asymptotically in ρ → ∞, then converted back to Cartesian coordinates (ξ,η).
  • Handles special case when the symbol is an exponential function by expanding its argument.
  • Symbolic normalization is applied early (via simplify) for 2D expressions to improve convergence.
  • Robust to failures: catches exceptions and issues warnings instead of raising errors.
  • Final expression is simplified using powdenest and expand for improved readability.
def compose_asymptotic(self, other, order=1, mode='kn', sign_convention=None):
732    def compose_asymptotic(self, other, order=1, mode='kn', sign_convention=None):
733        """
734        Compose two pseudo-differential operators using an asymptotic expansion
735        in the chosen quantization scheme (Kohn–Nirenberg or Weyl).
736    
737        Parameters
738        ----------
739        other : PseudoDifferentialOperator
740            The operator to compose with this one.
741        order : int, default=1
742            Maximum order of the asymptotic expansion.
743        mode : {'kn', 'weyl'}, default='kn'
744            Quantization mode:
745            - 'kn' : Kohn–Nirenberg quantization (left-quantized)
746            - 'weyl' : Weyl symmetric quantization
747        sign_convention : {'standard', 'inverse'}, optional
748            Controls the phase factor convention for the KN case:
749            - 'standard' → (i)^(-n), gives [x, ξ] = +i (physics convention)
750            - 'inverse' → (i)^(+n), gives [x, ξ] = -i (mathematical adjoint convention)
751            If None, defaults to 'standard'.
752    
753        Returns
754        -------
755        sympy.Expr
756            Symbolic expression for the composed symbol up to the given order.
757    
758        Notes
759        -----
760        - In 1D (Kohn–Nirenberg):
761            (p ∘ q)(x, ξ) ~ Σₙ (1/n!) (i sgn)^n ∂_ξⁿ p(x, ξ) ∂_xⁿ q(x, ξ)
762        - In 1D (Weyl):
763            (p # q)(x, ξ) = exp[(i/2)(∂_ξ^p ∂_x^q - ∂_x^p ∂_ξ^q)] p(x, ξ) q(x, ξ)
764            truncated at given order.
765    
766        Examples
767        --------
768        X = a*x, Y = b*ξ
769        X_op.compose_asymptotic(Y_op, order=3, mode='weyl')
770        """
771    
772        from sympy import diff, factorial, simplify, symbols
773    
774        assert self.dim == other.dim, "Operator dimensions must match"
775        p, q = self.symbol, other.symbol
776    
777        # Default sign convention
778        if sign_convention is None:
779            sign_convention = 'standard'
780        sign = -1 if sign_convention == 'standard' else +1
781    
782        # --- 1D case ---
783        if self.dim == 1:
784            x = self.vars_x[0]
785            xi = symbols('xi', real=True)
786            result = 0
787    
788            if mode == 'kn':  # Kohn–Nirenberg
789                for n in range(order + 1):
790                    term = (1 / factorial(n)) * diff(p, xi, n) * diff(q, x, n) * (1j) ** (sign * n)
791                    result += term
792    
793            elif mode == 'weyl':  # Weyl symmetric composition
794                # Weyl star product: exp((i/2)(∂_ξ^p ∂_x^q - ∂_x^p ∂_ξ^q))
795                result = 0
796                for n in range(order + 1):
797                    for k in range(n + 1):
798                        # k derivatives acting as (∂_ξ^k p)(∂_x^(n−k) q)
799                        coeff = (1 / (factorial(k) * factorial(n - k))) * ((1j / 2) ** n) * ((-1) ** (n - k))
800                        term = coeff * diff(p, xi, k, x, n - k, evaluate=True) * diff(q, x, k, xi, n - k, evaluate=True)
801                        result += term
802    
803            else:
804                raise ValueError("mode must be either 'kn' or 'weyl'")
805    
806            return simplify(result)
807    
808        # --- 2D case ---
809        elif self.dim == 2:
810            x, y = self.vars_x
811            xi, eta = symbols('xi eta', real=True)
812            result = 0
813    
814            if mode == 'kn':
815                for n in range(order + 1):
816                    for i in range(n + 1):
817                        j = n - i
818                        term = (1 / (factorial(i) * factorial(j))) * \
819                               diff(p, xi, i, eta, j) * diff(q, x, i, y, j) * (1j) ** (sign * n)
820                        result += term
821    
822            elif mode == 'weyl':
823                for n in range(order + 1):
824                    for i in range(n + 1):
825                        j = n - i
826                        coeff = (1 / (factorial(i) * factorial(j))) * ((1j / 2) ** n) * ((-1) ** (n - i))
827                        term = coeff * diff(p, xi, i, eta, j, x, 0, y, 0) * diff(q, x, i, y, j, xi, 0, eta, 0)
828                        result += term
829            else:
830                raise ValueError("mode must be either 'kn' or 'weyl'")
831    
832            return simplify(result)
833    
834        else:
835            raise NotImplementedError("Only 1D and 2D cases are implemented")

Compose two pseudo-differential operators using an asymptotic expansion in the chosen quantization scheme (Kohn–Nirenberg or Weyl).

Parameters

other : PseudoDifferentialOperator The operator to compose with this one. order : int, default=1 Maximum order of the asymptotic expansion. mode : {'kn', 'weyl'}, default='kn' Quantization mode: - 'kn' : Kohn–Nirenberg quantization (left-quantized) - 'weyl' : Weyl symmetric quantization sign_convention : {'standard', 'inverse'}, optional Controls the phase factor convention for the KN case: - 'standard' → (i)^(-n), gives [x, ξ] = +i (physics convention) - 'inverse' → (i)^(+n), gives [x, ξ] = -i (mathematical adjoint convention) If None, defaults to 'standard'.

Returns

sympy.Expr Symbolic expression for the composed symbol up to the given order.

Notes

  • In 1D (Kohn–Nirenberg): (p ∘ q)(x, ξ) ~ Σₙ (1/n!) (i sgn)^n ∂_ξⁿ p(x, ξ) ∂_xⁿ q(x, ξ)
  • In 1D (Weyl): (p # q)(x, ξ) = exp[(i/2)(∂_ξ^p ∂_x^q - ∂_x^p ∂_ξ^q)] p(x, ξ) q(x, ξ) truncated at given order.

Examples

X = ax, Y = bξ X_op.compose_asymptotic(Y_op, order=3, mode='weyl')

def commutator_symbolic(self, other, order=1, mode='kn', sign_convention=None):
837    def commutator_symbolic(self, other, order=1, mode='kn', sign_convention=None):
838        """
839        Compute the symbolic commutator [A, B] = A∘B − B∘A of two pseudo-differential operators
840        using formal asymptotic expansion of their composition symbols.
841    
842        This method computes the asymptotic expansion of the commutator's symbol up to a given 
843        order, based on the symbolic calculus of pseudo-differential operators in the 
844        Kohn–Nirenberg quantization. The result is a purely symbolic sympy expression that 
845        captures the leading-order noncommutativity of the operators.
846    
847        Parameters
848        ----------
849        other : PseudoDifferentialOperator
850            The pseudo-differential operator B to commute with this operator A.
851        order : int, default=1
852            Maximum order of the asymptotic expansion. 
853            - order=1 yields the leading term proportional to the Poisson bracket {p, q}.
854            - Higher orders include correction terms involving higher mixed derivatives.
855    
856        Returns
857        -------
858        sympy.Expr
859            Symbolic expression for the asymptotic expansion of the commutator symbol 
860            σ([A,B]) = σ(A∘B − B∘A).
861    
862        """
863        assert self.dim == other.dim, "Operator dimensions must match"
864        p, q = self.symbol, other.symbol
865    
866        pq = self.compose_asymptotic(other, order=order, mode=mode, sign_convention=sign_convention)
867        qp = other.compose_asymptotic(self, order=order, mode=mode, sign_convention=sign_convention)
868        
869        comm_symbol = simplify(pq-qp)
870
871        return comm_symbol

Compute the symbolic commutator [A, B] = A∘B − B∘A of two pseudo-differential operators using formal asymptotic expansion of their composition symbols.

This method computes the asymptotic expansion of the commutator's symbol up to a given order, based on the symbolic calculus of pseudo-differential operators in the Kohn–Nirenberg quantization. The result is a purely symbolic sympy expression that captures the leading-order noncommutativity of the operators.

Parameters

other : PseudoDifferentialOperator The pseudo-differential operator B to commute with this operator A. order : int, default=1 Maximum order of the asymptotic expansion. - order=1 yields the leading term proportional to the Poisson bracket {p, q}. - Higher orders include correction terms involving higher mixed derivatives.

Returns

sympy.Expr Symbolic expression for the asymptotic expansion of the commutator symbol σ([A,B]) = σ(A∘B − B∘A).

def right_inverse_asymptotic(self, order=1):
873    def right_inverse_asymptotic(self, order=1):
874        """
875        Construct a formal right inverse R of the pseudo-differential operator P such that 
876        the composition P ∘ R equals the identity plus a smoothing operator of order -order.
877    
878        This method computes an asymptotic expansion for the right inverse using recursive 
879        corrections based on derivatives of the symbol p(x, ξ) and lower-order terms of R.
880    
881        Parameters
882        ----------
883        order : int
884            Number of terms to include in the asymptotic expansion. Higher values improve 
885            approximation at the cost of complexity and computational effort.
886    
887        Returns
888        -------
889        sympy.Expr
890            The symbolic expression representing the formal right inverse R(x, ξ), which satisfies:
891            P ∘ R = Id + O(⟨ξ⟩^{-order}), where ⟨ξ⟩ = (1 + |ξ|²)^{1/2}.
892    
893        Notes
894        -----
895        - In 1D: The recursion involves spatial derivatives of R and derivatives of p with respect to ξ.
896        - In 2D: The multi-index generalization is used with mixed derivatives in ξ and η.
897        - The construction relies on the non-vanishing of the principal symbol p to ensure invertibility.
898        - Each term in the expansion corresponds to higher-order corrections involving commutators 
899          between the operator P and the current approximation of R.
900        """
901        p = self.symbol
902        if self.dim == 1:
903            x = self.vars_x[0]
904            xi = symbols('xi', real=True)
905            r = 1 / p.subs(xi, xi)  # r0
906            R = r
907            for n in range(1, order + 1):
908                term = 0
909                for k in range(1, n + 1):
910                    coeff = (1j)**(-k) / factorial(k)
911                    inner = diff(p, xi, k) * diff(R, x, k)
912                    term += coeff * inner
913                R = R - r * term
914        elif self.dim == 2:
915            x, y = self.vars_x
916            xi, eta = symbols('xi eta', real=True)
917            r = 1 / p.subs({xi: xi, eta: eta})
918            R = r
919            for n in range(1, order + 1):
920                term = 0
921                for k1 in range(n + 1):
922                    for k2 in range(n + 1 - k1):
923                        if k1 + k2 == 0: continue
924                        coeff = (1j)**(-(k1 + k2)) / (factorial(k1) * factorial(k2))
925                        dp = diff(p, xi, k1, eta, k2)
926                        dR = diff(R, x, k1, y, k2)
927                        term += coeff * dp * dR
928                R = R - r * term
929        return R

Construct a formal right inverse R of the pseudo-differential operator P such that the composition P ∘ R equals the identity plus a smoothing operator of order -order.

This method computes an asymptotic expansion for the right inverse using recursive corrections based on derivatives of the symbol p(x, ξ) and lower-order terms of R.

Parameters

order : int Number of terms to include in the asymptotic expansion. Higher values improve approximation at the cost of complexity and computational effort.

Returns

sympy.Expr The symbolic expression representing the formal right inverse R(x, ξ), which satisfies: P ∘ R = Id + O(⟨ξ⟩^{-order}), where ⟨ξ⟩ = (1 + |ξ|²)^{1/2}.

Notes

  • In 1D: The recursion involves spatial derivatives of R and derivatives of p with respect to ξ.
  • In 2D: The multi-index generalization is used with mixed derivatives in ξ and η.
  • The construction relies on the non-vanishing of the principal symbol p to ensure invertibility.
  • Each term in the expansion corresponds to higher-order corrections involving commutators between the operator P and the current approximation of R.
def left_inverse_asymptotic(self, order=1):
931    def left_inverse_asymptotic(self, order=1):
932        """
933        Construct a formal left inverse L such that the composition L ∘ P equals the identity 
934        operator up to terms of order ξ^{-order}. This expansion is performed asymptotically 
935        at infinity in the frequency variable(s).
936    
937        The left inverse is built iteratively using symbolic differentiation and the 
938        method of asymptotic expansions for pseudo-differential operators. It ensures that:
939        
940            L(P(x,ξ),x,D) ∘ P(x,D) = Id + smoothing operator of order -order
941    
942        Parameters
943        ----------
944        order : int, optional
945            Maximum number of terms in the asymptotic expansion (default is 1). Higher values 
946            yield more accurate inverses at the cost of increased computational complexity.
947    
948        Returns
949        -------
950        sympy.Expr
951            Symbolic expression representing the principal symbol of the formal left inverse 
952            operator L(x,ξ). This expression depends on spatial variables and frequencies, 
953            and includes correction terms up to the specified order.
954    
955        Notes
956        -----
957        - In 1D: Uses recursive application of the Leibniz formula for symbols.
958        - In 2D: Generalizes to multi-indices for mixed derivatives in (x,y) and (ξ,η).
959        - Each term involves combinations of derivatives of the original symbol p(x,ξ) and 
960          previously computed terms of the inverse.
961        - Coefficients include powers of 1j (i) and factorial normalization for derivative terms.
962        """
963        p = self.symbol
964        if self.dim == 1:
965            x = self.vars_x[0]
966            xi = symbols('xi', real=True)
967            l = 1 / p.subs(xi, xi)
968            L = l
969            for n in range(1, order + 1):
970                term = 0
971                for k in range(1, n + 1):
972                    coeff = (1j)**(-k) / factorial(k)
973                    inner = diff(L, xi, k) * diff(p, x, k)
974                    term += coeff * inner
975                L = L - term * l
976        elif self.dim == 2:
977            x, y = self.vars_x
978            xi, eta = symbols('xi eta', real=True)
979            l = 1 / p.subs({xi: xi, eta: eta})
980            L = l
981            for n in range(1, order + 1):
982                term = 0
983                for k1 in range(n + 1):
984                    for k2 in range(n + 1 - k1):
985                        if k1 + k2 == 0: continue
986                        coeff = (1j)**(-(k1 + k2)) / (factorial(k1) * factorial(k2))
987                        dp = diff(p, x, k1, y, k2)
988                        dL = diff(L, xi, k1, eta, k2)
989                        term += coeff * dL * dp
990                L = L - term * l
991        return L

Construct a formal left inverse L such that the composition L ∘ P equals the identity operator up to terms of order ξ^{-order}. This expansion is performed asymptotically at infinity in the frequency variable(s).

The left inverse is built iteratively using symbolic differentiation and the method of asymptotic expansions for pseudo-differential operators. It ensures that:

L(P(x,ξ),x,D) ∘ P(x,D) = Id + smoothing operator of order -order

Parameters

order : int, optional Maximum number of terms in the asymptotic expansion (default is 1). Higher values yield more accurate inverses at the cost of increased computational complexity.

Returns

sympy.Expr Symbolic expression representing the principal symbol of the formal left inverse operator L(x,ξ). This expression depends on spatial variables and frequencies, and includes correction terms up to the specified order.

Notes

  • In 1D: Uses recursive application of the Leibniz formula for symbols.
  • In 2D: Generalizes to multi-indices for mixed derivatives in (x,y) and (ξ,η).
  • Each term involves combinations of derivatives of the original symbol p(x,ξ) and previously computed terms of the inverse.
  • Coefficients include powers of 1j (i) and factorial normalization for derivative terms.
def formal_adjoint(self):
 993    def formal_adjoint(self):
 994        """
 995        Compute the formal adjoint symbol P* of the pseudo-differential operator.
 996
 997        The adjoint is defined such that for any test functions u and v,
 998        ⟨P u, v⟩ = ⟨u, P* v⟩ holds in the distributional sense. This is obtained by 
 999        taking the complex conjugate of the symbol and expanding it asymptotically 
1000        at infinity to ensure proper behavior under integration by parts.
1001
1002        Returns
1003        -------
1004        sympy.Expr
1005            The adjoint symbol P*(x, ξ) in 1D or P*(x, y, ξ, η) in 2D.
1006        
1007        Notes:
1008        - In 1D, the expansion is performed in powers of 1/|ξ|.
1009        - In 2D, the expansion is radial in |ξ| = sqrt(ξ² + η²).
1010        - This method ensures symbolic simplifications for readability and efficiency.
1011        """
1012        p = self.symbol
1013        if self.dim == 1:
1014            x, = self.vars_x
1015            xi = symbols('xi', real=True)
1016            p_star = conjugate(p)
1017            p_star = simplify(series(p_star, xi, oo, n=6).removeO())
1018            return p_star
1019        elif self.dim == 2:
1020            x, y = self.vars_x
1021            xi, eta = symbols('xi eta', real=True)
1022            p_star = conjugate(p)
1023            p_star = simplify(series(p_star, sqrt(xi**2 + eta**2), oo, n=6).removeO())
1024            return p_star

Compute the formal adjoint symbol P* of the pseudo-differential operator.

The adjoint is defined such that for any test functions u and v, ⟨P u, v⟩ = ⟨u, P* v⟩ holds in the distributional sense. This is obtained by taking the complex conjugate of the symbol and expanding it asymptotically at infinity to ensure proper behavior under integration by parts.

Returns

sympy.Expr The adjoint symbol P(x, ξ) in 1D or P(x, y, ξ, η) in 2D.

Notes:

  • In 1D, the expansion is performed in powers of 1/|ξ|.
  • In 2D, the expansion is radial in |ξ| = sqrt(ξ² + η²).
  • This method ensures symbolic simplifications for readability and efficiency.
def exponential_symbol(self, t=1.0, order=1, mode='kn', sign_convention=None):
1026    def exponential_symbol(self, t=1.0, order=1, mode='kn', sign_convention=None):
1027        """
1028        Compute the symbol of exp(tP) using asymptotic expansion methods.
1029        
1030        This method calculates the exponential of a pseudo-differential operator 
1031        using either a direct power series expansion or a Magnus expansion, 
1032        depending on the structure of the symbol. The result is valid up to 
1033        the specified asymptotic order.
1034        
1035        Parameters
1036        ----------
1037        t : float or sympy.Symbol, default=1.0
1038            Time or evolution parameter. Common uses:
1039            - t = -i*τ for Schrödinger evolution: exp(-iτH)
1040            - t = τ for heat/diffusion: exp(τΔ)
1041            - t for general propagators
1042        order : int, default=3
1043            Maximum order of the asymptotic expansion. Higher orders include 
1044            more composition terms, improving accuracy for small t or when 
1045            non-commutativity effects are significant.
1046        
1047        Returns
1048        -------
1049        sympy.Expr
1050            Symbolic expression for the exponential operator symbol, computed 
1051            as an asymptotic series up to the specified order.
1052        
1053        Notes
1054        -----
1055        - For commutative symbols (e.g., pure multiplication operators), the 
1056          exponential is exact: exp(tP) = exp(t*p(x,ξ)).
1057        
1058        - For general non-commutative operators, the method uses the BCH-type 
1059          expansion via iterated composition:
1060          exp(tP) ~ I + tP + (t²/2!)P∘P + (t³/3!)P∘P∘P + ...
1061          
1062        - Each power P^n is computed via compose_asymptotic, which accounts 
1063          for the non-commutativity through derivative terms.
1064        
1065        - The expansion is valid for |t| small enough or when the symbol has 
1066          appropriate decay/growth properties.
1067        
1068        - In quantum mechanics (Schrödinger): U(t) = exp(-itH/ℏ) represents 
1069          the time evolution operator.
1070        
1071        - In parabolic PDEs (heat equation): exp(tΔ) is the heat kernel.
1072
1073        """
1074        if self.dim == 1:
1075            x = self.vars_x[0]
1076            xi = symbols('xi', real=True)
1077            
1078            # Initialize with identity
1079            result = 1
1080            
1081            # First order term: tP
1082            current_power = self.symbol
1083            result += t * current_power
1084            
1085            # Higher order terms: (t^n/n!) P^n computed via composition
1086            for n in range(2, order + 1):
1087                # Compute P^n = P^(n-1) ∘ P via asymptotic composition
1088                # We use a temporary operator for composition
1089                temp_op = PseudoDifferentialOperator(
1090                    current_power, [x], mode='symbol'
1091                )
1092                current_power = temp_op.compose_asymptotic(self, order=order, mode=mode, sign_convention=sign_convention)
1093                
1094                # Add term (t^n/n!) * P^n
1095                coeff = t**n / factorial(n)
1096                result += coeff * current_power
1097            
1098            return simplify(result)
1099        
1100        elif self.dim == 2:
1101            x, y = self.vars_x
1102            xi, eta = symbols('xi eta', real=True)
1103            
1104            # Initialize with identity
1105            result = 1
1106            
1107            # First order term: tP
1108            current_power = self.symbol
1109            result += t * current_power
1110            
1111            # Higher order terms: (t^n/n!) P^n computed via composition
1112            for n in range(2, order + 1):
1113                # Compute P^n = P^(n-1) ∘ P via asymptotic composition
1114                temp_op = PseudoDifferentialOperator(
1115                    current_power, [x, y], mode='symbol'
1116                )
1117                current_power = temp_op.compose_asymptotic(self, order=order, mode=mode, sign_convention=sign_convention)
1118                
1119                # Add term (t^n/n!) * P^n
1120                coeff = t**n / factorial(n)
1121                result += coeff * current_power
1122            
1123            return simplify(result)
1124        
1125        else:
1126            raise NotImplementedError("Only 1D and 2D operators are supported")

Compute the symbol of exp(tP) using asymptotic expansion methods.

This method calculates the exponential of a pseudo-differential operator using either a direct power series expansion or a Magnus expansion, depending on the structure of the symbol. The result is valid up to the specified asymptotic order.

Parameters

t : float or sympy.Symbol, default=1.0 Time or evolution parameter. Common uses: - t = -i*τ for Schrödinger evolution: exp(-iτH) - t = τ for heat/diffusion: exp(τΔ) - t for general propagators order : int, default=3 Maximum order of the asymptotic expansion. Higher orders include more composition terms, improving accuracy for small t or when non-commutativity effects are significant.

Returns

sympy.Expr Symbolic expression for the exponential operator symbol, computed as an asymptotic series up to the specified order.

Notes

  • For commutative symbols (e.g., pure multiplication operators), the exponential is exact: exp(tP) = exp(t*p(x,ξ)).

  • For general non-commutative operators, the method uses the BCH-type expansion via iterated composition: exp(tP) ~ I + tP + (t²/2!)P∘P + (t³/3!)P∘P∘P + ...

  • Each power P^n is computed via compose_asymptotic, which accounts for the non-commutativity through derivative terms.

  • The expansion is valid for |t| small enough or when the symbol has appropriate decay/growth properties.

  • In quantum mechanics (Schrödinger): U(t) = exp(-itH/ℏ) represents the time evolution operator.

  • In parabolic PDEs (heat equation): exp(tΔ) is the heat kernel.

def trace_formula( self, volume_element=None, numerical=False, x_bounds=None, xi_bounds=None):
1128    def trace_formula(self, volume_element=None, numerical=False, 
1129                      x_bounds=None, xi_bounds=None):
1130        """
1131        Compute the semiclassical trace of the pseudo-differential operator.
1132        
1133        The trace formula relates the quantum trace of an operator to a 
1134        phase-space integral of its symbol, providing a fundamental link 
1135        between classical and quantum mechanics. This implementation supports 
1136        both symbolic and numerical integration.
1137        
1138        Parameters
1139        ----------
1140        volume_element : sympy.Expr, optional
1141            Custom volume element for the phase space integration. If None, 
1142            uses the standard Liouville measure dx dξ/(2π)^d.
1143        numerical : bool, default=False
1144            If True, perform numerical integration over specified bounds.
1145            If False, attempt symbolic integration (may fail for complex symbols).
1146        x_bounds : tuple of tuples, optional
1147            Spatial integration bounds. For 1D: ((x_min, x_max),)
1148            For 2D: ((x_min, x_max), (y_min, y_max))
1149            Required if numerical=True.
1150        xi_bounds : tuple of tuples, optional
1151            Frequency integration bounds. For 1D: ((xi_min, xi_max),)
1152            For 2D: ((xi_min, xi_max), (eta_min, eta_max))
1153            Required if numerical=True.
1154        
1155        Returns
1156        -------
1157        sympy.Expr or float
1158            The trace of the operator. Returns a symbolic expression if 
1159            numerical=False, or a float if numerical=True.
1160        
1161        Notes
1162        -----
1163        - The semiclassical trace formula states:
1164          Tr(P) = (2π)^{-d} ∫∫ p(x,ξ) dx dξ
1165          where d is the spatial dimension and p(x,ξ) is the operator symbol.
1166        
1167        - For 1D: Tr(P) = (1/2π) ∫_{-∞}^{∞} ∫_{-∞}^{∞} p(x,ξ) dx dξ
1168        
1169        - For 2D: Tr(P) = (1/4π²) ∫∫∫∫ p(x,y,ξ,η) dx dy dξ dη
1170        
1171        - This formula is exact for trace-class operators and provides an 
1172          asymptotic approximation for general pseudo-differential operators.
1173        
1174        - Physical interpretation: the trace counts the "number of states" 
1175          weighted by the observable p(x,ξ).
1176        
1177        - For projection operators (χ_Ω with χ² = χ), the trace gives the 
1178          dimension of the range, related to the phase space volume of Ω.
1179        
1180        - The factor (2π)^{-d} comes from the quantum normalization of 
1181          coherent states / Weyl quantization.
1182        """
1183        from sympy import integrate, simplify, lambdify
1184        from scipy.integrate import dblquad, nquad
1185        
1186        p = self.symbol
1187        
1188        if numerical:
1189            if x_bounds is None or xi_bounds is None:
1190                raise ValueError(
1191                    "x_bounds and xi_bounds must be provided for numerical integration"
1192                )
1193        
1194        if self.dim == 1:
1195            x, = self.vars_x
1196            xi = symbols('xi', real=True)
1197            
1198            if volume_element is None:
1199                volume_element = 1 / (2 * pi)
1200            
1201            if numerical:
1202                # Numerical integration
1203                p_func = lambdify((x, xi), p, 'numpy')
1204                (x_min, x_max), = x_bounds
1205                (xi_min, xi_max), = xi_bounds
1206                
1207                def integrand(xi_val, x_val):
1208                    return p_func(x_val, xi_val)
1209                
1210                result, error = dblquad(
1211                    integrand,
1212                    x_min, x_max,
1213                    lambda x: xi_min, lambda x: xi_max
1214                )
1215                
1216                result *= float(volume_element)
1217                print(f"Numerical trace = {result:.6e} ± {error:.6e}")
1218                return result
1219            
1220            else:
1221                # Symbolic integration
1222                integrand = p * volume_element
1223                
1224                try:
1225                    # Try to integrate over xi first, then x
1226                    integral_xi = integrate(integrand, (xi, -oo, oo))
1227                    integral_x = integrate(integral_xi, (x, -oo, oo))
1228                    return simplify(integral_x)
1229                except:
1230                    print("Warning: Symbolic integration failed. Try numerical=True")
1231                    return integrate(integrand, (xi, -oo, oo), (x, -oo, oo))
1232        
1233        elif self.dim == 2:
1234            x, y = self.vars_x
1235            xi, eta = symbols('xi eta', real=True)
1236            
1237            if volume_element is None:
1238                volume_element = 1 / (4 * pi**2)
1239            
1240            if numerical:
1241                # Numerical integration in 4D
1242                p_func = lambdify((x, y, xi, eta), p, 'numpy')
1243                (x_min, x_max), (y_min, y_max) = x_bounds
1244                (xi_min, xi_max), (eta_min, eta_max) = xi_bounds
1245                
1246                def integrand(eta_val, xi_val, y_val, x_val):
1247                    return p_func(x_val, y_val, xi_val, eta_val)
1248                
1249                result, error = nquad(
1250                    integrand,
1251                    [
1252                        [eta_min, eta_max],
1253                        [xi_min, xi_max],
1254                        [y_min, y_max],
1255                        [x_min, x_max]
1256                    ]
1257                )
1258                
1259                result *= float(volume_element)
1260                print(f"Numerical trace = {result:.6e} ± {error:.6e}")
1261                return result
1262            
1263            else:
1264                # Symbolic integration
1265                integrand = p * volume_element
1266                
1267                try:
1268                    # Integrate in order: eta, xi, y, x
1269                    integral_eta = integrate(integrand, (eta, -oo, oo))
1270                    integral_xi = integrate(integral_eta, (xi, -oo, oo))
1271                    integral_y = integrate(integral_xi, (y, -oo, oo))
1272                    integral_x = integrate(integral_y, (x, -oo, oo))
1273                    return simplify(integral_x)
1274                except:
1275                    print("Warning: Symbolic integration failed. Try numerical=True")
1276                    return integrate(
1277                        integrand,
1278                        (eta, -oo, oo), (xi, -oo, oo),
1279                        (y, -oo, oo), (x, -oo, oo)
1280                    )
1281        
1282        else:
1283            raise NotImplementedError("Only 1D and 2D operators are supported")

Compute the semiclassical trace of the pseudo-differential operator.

The trace formula relates the quantum trace of an operator to a phase-space integral of its symbol, providing a fundamental link between classical and quantum mechanics. This implementation supports both symbolic and numerical integration.

Parameters

volume_element : sympy.Expr, optional Custom volume element for the phase space integration. If None, uses the standard Liouville measure dx dξ/(2π)^d. numerical : bool, default=False If True, perform numerical integration over specified bounds. If False, attempt symbolic integration (may fail for complex symbols). x_bounds : tuple of tuples, optional Spatial integration bounds. For 1D: ((x_min, x_max),) For 2D: ((x_min, x_max), (y_min, y_max)) Required if numerical=True. xi_bounds : tuple of tuples, optional Frequency integration bounds. For 1D: ((xi_min, xi_max),) For 2D: ((xi_min, xi_max), (eta_min, eta_max)) Required if numerical=True.

Returns

sympy.Expr or float The trace of the operator. Returns a symbolic expression if numerical=False, or a float if numerical=True.

Notes

  • The semiclassical trace formula states: Tr(P) = (2π)^{-d} ∫∫ p(x,ξ) dx dξ where d is the spatial dimension and p(x,ξ) is the operator symbol.

  • For 1D: Tr(P) = (1/2π) ∫_{-∞}^{∞} ∫_{-∞}^{∞} p(x,ξ) dx dξ

  • For 2D: Tr(P) = (1/4π²) ∫∫∫∫ p(x,y,ξ,η) dx dy dξ dη

  • This formula is exact for trace-class operators and provides an asymptotic approximation for general pseudo-differential operators.

  • Physical interpretation: the trace counts the "number of states" weighted by the observable p(x,ξ).

  • For projection operators (χ_Ω with χ² = χ), the trace gives the dimension of the range, related to the phase space volume of Ω.

  • The factor (2π)^{-d} comes from the quantum normalization of coherent states / Weyl quantization.

def pseudospectrum_analysis( self, x_grid, lambda_real_range, lambda_imag_range, epsilon_levels=[0.1, 0.01, 0.001, 0.0001], resolution=100, method='spectral', L=None, N=None, use_sparse=False, parallel=True, n_workers=4, adaptive=False, adaptive_threshold=0.5, auto_range=True, plot=True):
1285    def pseudospectrum_analysis(self, x_grid, lambda_real_range, lambda_imag_range,
1286                               epsilon_levels=[0.1, 0.01, 0.001, 0.0001],
1287                               resolution=100, method='spectral', L=None, N=None,
1288                               use_sparse=False, parallel=True, n_workers=4,
1289                               adaptive=False, adaptive_threshold=0.5,
1290                               auto_range=True, plot=True):
1291        """
1292        Compute and visualize the pseudospectrum of the operator.
1293        
1294        Optimizations:
1295        - Uses apply() method instead of manual loops
1296        - Parallel computation of resolvent norms
1297        - Sparse matrix support for large N
1298        - Optional adaptive grid refinement
1299        
1300        Parameters
1301        ----------
1302        x_grid : array
1303            Spatial grid for quantization
1304        lambda_real_range : tuple
1305            (min, max) for real part of λ
1306        lambda_imag_range : tuple
1307            (min, max) for imaginary part of λ
1308        epsilon_levels : list
1309            Levels for ε-pseudospectrum contours
1310        resolution : int
1311            Grid resolution for λ sampling
1312        method : str
1313            'spectral' or 'finite_difference'
1314        L : float, optional
1315            Domain half-length for spectral method
1316        N : int, optional
1317            Number of grid points
1318        use_sparse : bool
1319            Use sparse matrices for large N
1320        parallel : bool
1321            Enable parallel computation
1322        n_workers : int
1323            Number of parallel workers
1324        adaptive : bool
1325            Use adaptive grid refinement
1326        adaptive_threshold : float
1327            Threshold for adaptive refinement
1328            
1329        Returns
1330        -------
1331        dict
1332            Dictionary with pseudospectrum data and operator matrix
1333        """
1334        if self.dim != 1:
1335            raise NotImplementedError('Pseudospectrum analysis currently supports 1D only')
1336        
1337        # Step 1: Build operator matrix
1338        print(f"Building operator matrix using '{method}' method...")
1339        H, x_grid_used, k_grid = self._build_operator_matrix(x_grid, method, L, N)
1340        N_actual = H.shape[0]
1341        
1342        # Step 1.5: Compute eigenvalues FIRST to adjust range if needed
1343        print('Computing eigenvalues...')
1344        eigenvalues = self._compute_eigenvalues(H, use_sparse)
1345        
1346        # Auto-adjust range if requested
1347        if auto_range and eigenvalues is not None:
1348            eig_real_min, eig_real_max = eigenvalues.real.min(), eigenvalues.real.max()
1349            eig_imag_min, eig_imag_max = eigenvalues.imag.min(), eigenvalues.imag.max()
1350            
1351            # Add 20% margin around eigenvalues
1352            margin_real = 0.2 * (eig_real_max - eig_real_min + 1)
1353            margin_imag = max(0.2 * (eig_imag_max - eig_imag_min + 1), 2.0)
1354            
1355            lambda_real_range = (eig_real_min - margin_real, eig_real_max + margin_real)
1356            lambda_imag_range = (eig_imag_min - margin_imag, eig_imag_max + margin_imag)
1357            
1358            print(f'Auto-adjusted λ range:')
1359            print(f'  Re(λ) ∈ [{lambda_real_range[0]:.2f}, {lambda_real_range[1]:.2f}]')
1360            print(f'  Im(λ) ∈ [{lambda_imag_range[0]:.2f}, {lambda_imag_range[1]:.2f}]')
1361        
1362        # Step 2: Compute pseudospectrum with corrected range
1363        print(f'Computing pseudospectrum over {resolution}×{resolution} grid...')
1364        if adaptive:
1365            print('Using adaptive grid refinement...')
1366            Lambda, resolvent_norm, sigma_min_grid = self._compute_pseudospectrum_adaptive(
1367                H, lambda_real_range, lambda_imag_range, resolution,
1368                use_sparse=use_sparse, parallel=parallel, n_workers=n_workers,
1369                threshold=adaptive_threshold
1370            )
1371        else:
1372            Lambda, resolvent_norm, sigma_min_grid = self._compute_pseudospectrum(
1373                H, lambda_real_range, lambda_imag_range, resolution,
1374                use_sparse=use_sparse, parallel=parallel, n_workers=n_workers
1375            )
1376        
1377        # Step 3: Visualize
1378        if plot:
1379            self._plot_pseudospectrum(Lambda, resolvent_norm, sigma_min_grid,
1380                                      epsilon_levels, eigenvalues)
1381        
1382        return {
1383            'lambda_grid': Lambda,
1384            'resolvent_norm': resolvent_norm,
1385            'sigma_min': sigma_min_grid,
1386            'epsilon_levels': epsilon_levels,
1387            'eigenvalues': eigenvalues,
1388            'operator_matrix': H,
1389            'x_grid': x_grid_used,
1390            'k_grid': k_grid
1391        }

Compute and visualize the pseudospectrum of the operator.

Optimizations:

  • Uses apply() method instead of manual loops
  • Parallel computation of resolvent norms
  • Sparse matrix support for large N
  • Optional adaptive grid refinement

Parameters

x_grid : array Spatial grid for quantization lambda_real_range : tuple (min, max) for real part of λ lambda_imag_range : tuple (min, max) for imaginary part of λ epsilon_levels : list Levels for ε-pseudospectrum contours resolution : int Grid resolution for λ sampling method : str 'spectral' or 'finite_difference' L : float, optional Domain half-length for spectral method N : int, optional Number of grid points use_sparse : bool Use sparse matrices for large N parallel : bool Enable parallel computation n_workers : int Number of parallel workers adaptive : bool Use adaptive grid refinement adaptive_threshold : float Threshold for adaptive refinement

Returns

dict Dictionary with pseudospectrum data and operator matrix

def symplectic_flow(self):
1813    def symplectic_flow(self):
1814        """
1815        Compute the Hamiltonian vector field associated with the principal symbol.
1816
1817        This method derives the canonical equations of motion for the phase space variables 
1818        (x, ξ) in 1D or (x, y, ξ, η) in 2D, based on the Hamiltonian formalism. These describe 
1819        how position and frequency variables evolve under the flow generated by the symbol.
1820
1821        Returns
1822        -------
1823        dict
1824            A dictionary containing the components of the Hamiltonian vector field:
1825            - In 1D: keys are 'dx/dt' and 'dxi/dt', corresponding to dx/dt = ∂p/∂ξ and dξ/dt = -∂p/∂x.
1826            - In 2D: keys are 'dx/dt', 'dy/dt', 'dxi/dt', and 'deta/dt', with similar definitions:
1827              dx/dt = ∂p/∂ξ, dy/dt = ∂p/∂η, dξ/dt = -∂p/∂x, dη/dt = -∂p/∂y.
1828
1829        Notes
1830        -----
1831        - The Hamiltonian here is the principal symbol p(x, ξ) itself.
1832        - This flow preserves the symplectic structure of phase space.
1833        """
1834        if self.dim == 1:
1835            x,  = self.vars_x
1836            xi = symbols('xi', real=True)
1837            return {
1838                'dx/dt': diff(self.symbol, xi),
1839                'dxi/dt': -diff(self.symbol, x)
1840            }
1841        elif self.dim == 2:
1842            x, y = self.vars_x
1843            xi, eta = symbols('xi eta', real=True)
1844            return {
1845                'dx/dt': diff(self.symbol, xi),
1846                'dy/dt': diff(self.symbol, eta),
1847                'dxi/dt': -diff(self.symbol, x),
1848                'deta/dt': -diff(self.symbol, y)
1849            }

Compute the Hamiltonian vector field associated with the principal symbol.

This method derives the canonical equations of motion for the phase space variables (x, ξ) in 1D or (x, y, ξ, η) in 2D, based on the Hamiltonian formalism. These describe how position and frequency variables evolve under the flow generated by the symbol.

Returns

dict A dictionary containing the components of the Hamiltonian vector field: - In 1D: keys are 'dx/dt' and 'dxi/dt', corresponding to dx/dt = ∂p/∂ξ and dξ/dt = -∂p/∂x. - In 2D: keys are 'dx/dt', 'dy/dt', 'dxi/dt', and 'deta/dt', with similar definitions: dx/dt = ∂p/∂ξ, dy/dt = ∂p/∂η, dξ/dt = -∂p/∂x, dη/dt = -∂p/∂y.

Notes

  • The Hamiltonian here is the principal symbol p(x, ξ) itself.
  • This flow preserves the symplectic structure of phase space.
def is_elliptic_numerically(self, x_grid, xi_grid, threshold=1e-08):
1851    def is_elliptic_numerically(self, x_grid, xi_grid, threshold=1e-8):
1852        """
1853        Check if the pseudo-differential symbol p(x, ξ) is elliptic over a given grid.
1854    
1855        A symbol is considered elliptic if its magnitude |p(x, ξ)| remains bounded away from zero 
1856        across all points in the spatial-frequency domain. This method evaluates the symbol on a 
1857        grid of spatial and frequency coordinates and checks whether its minimum absolute value 
1858        exceeds a specified threshold.
1859    
1860        Resampling is applied to large grids to prevent excessive memory usage, particularly in 2D.
1861    
1862        Parameters
1863        ----------
1864        x_grid : ndarray
1865            Spatial grid: either a 1D array (x) or a tuple of two 1D arrays (x, y).
1866        xi_grid : ndarray
1867            Frequency grid: either a 1D array (ξ) or a tuple of two 1D arrays (ξ, η).
1868        threshold : float, optional
1869            Minimum acceptable value for |p(x, ξ)|. If the smallest evaluated symbol value falls below this,
1870            the symbol is not considered elliptic.
1871    
1872        Returns
1873        -------
1874        bool
1875            True if the symbol is elliptic on the resampled grid, False otherwise.
1876        """
1877        RESAMPLE_SIZE = 32  # Reduced size to prevent memory explosion
1878        
1879        if self.dim == 1:
1880            x_vals = x_grid
1881            xi_vals = xi_grid
1882            # Resampling if necessary
1883            if len(x_vals) > RESAMPLE_SIZE:
1884                x_vals = np.linspace(x_vals.min(), x_vals.max(), RESAMPLE_SIZE)
1885            if len(xi_vals) > RESAMPLE_SIZE:
1886                xi_vals = np.linspace(xi_vals.min(), xi_vals.max(), RESAMPLE_SIZE)
1887        
1888            X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
1889            symbol_vals = self.p_func(X, XI)
1890        
1891        elif self.dim == 2:
1892            x_vals, y_vals = x_grid
1893            xi_vals, eta_vals = xi_grid
1894        
1895            # Spatial resampling
1896            if len(x_vals) > RESAMPLE_SIZE:
1897                x_vals = np.linspace(x_vals.min(), x_vals.max(), RESAMPLE_SIZE)
1898            if len(y_vals) > RESAMPLE_SIZE:
1899                y_vals = np.linspace(y_vals.min(), y_vals.max(), RESAMPLE_SIZE)
1900        
1901            # Frequency resampling
1902            if len(xi_vals) > RESAMPLE_SIZE:
1903                xi_vals = np.linspace(xi_vals.min(), xi_vals.max(), RESAMPLE_SIZE)
1904            if len(eta_vals) > RESAMPLE_SIZE:
1905                eta_vals = np.linspace(eta_vals.min(), eta_vals.max(), RESAMPLE_SIZE)
1906        
1907            X, Y, XI, ETA = np.meshgrid(x_vals, y_vals, xi_vals, eta_vals, indexing='ij')
1908            symbol_vals = self.p_func(X, Y, XI, ETA)
1909        
1910        min_abs_val = np.min(np.abs(symbol_vals))
1911        return min_abs_val > threshold

Check if the pseudo-differential symbol p(x, ξ) is elliptic over a given grid.

A symbol is considered elliptic if its magnitude |p(x, ξ)| remains bounded away from zero across all points in the spatial-frequency domain. This method evaluates the symbol on a grid of spatial and frequency coordinates and checks whether its minimum absolute value exceeds a specified threshold.

Resampling is applied to large grids to prevent excessive memory usage, particularly in 2D.

Parameters

x_grid : ndarray Spatial grid: either a 1D array (x) or a tuple of two 1D arrays (x, y). xi_grid : ndarray Frequency grid: either a 1D array (ξ) or a tuple of two 1D arrays (ξ, η). threshold : float, optional Minimum acceptable value for |p(x, ξ)|. If the smallest evaluated symbol value falls below this, the symbol is not considered elliptic.

Returns

bool True if the symbol is elliptic on the resampled grid, False otherwise.

def is_self_adjoint(self, tol=1e-10):
1914    def is_self_adjoint(self, tol=1e-10):
1915        """
1916        Check whether the pseudo-differential operator is formally self-adjoint (Hermitian).
1917
1918        A self-adjoint operator satisfies P = P*, where P* is the formal adjoint of P.
1919        This property is essential for ensuring real-valued eigenvalues and stable evolution 
1920        in quantum mechanics and symmetric wave propagation.
1921
1922        Parameters
1923        ----------
1924        tol : float
1925            Tolerance for symbolic comparison between P and P*. Small numerical differences 
1926            below this threshold are considered equal.
1927
1928        Returns
1929        -------
1930        bool
1931            True if the symbol p(x, ξ) equals its formal adjoint p*(x, ξ) within the given tolerance,
1932            indicating that the operator is self-adjoint.
1933
1934        Notes:
1935        - The formal adjoint is computed via conjugation and asymptotic expansion at infinity in ξ.
1936        - Symbolic simplification is used to verify equality, ensuring robustness against superficial 
1937          expression differences.
1938        """
1939        p = self.symbol
1940        p_star = self.formal_adjoint()
1941        return simplify(p - p_star).equals(0)

Check whether the pseudo-differential operator is formally self-adjoint (Hermitian).

A self-adjoint operator satisfies P = P, where P is the formal adjoint of P. This property is essential for ensuring real-valued eigenvalues and stable evolution in quantum mechanics and symmetric wave propagation.

Parameters

tol : float Tolerance for symbolic comparison between P and P*. Small numerical differences below this threshold are considered equal.

Returns

bool True if the symbol p(x, ξ) equals its formal adjoint p*(x, ξ) within the given tolerance, indicating that the operator is self-adjoint.

Notes:

  • The formal adjoint is computed via conjugation and asymptotic expansion at infinity in ξ.
  • Symbolic simplification is used to verify equality, ensuring robustness against superficial expression differences.
def visualize_fiber(self, x_grid, xi_grid, x0=0.0, y0=0.0):
1943    def visualize_fiber(self, x_grid, xi_grid, x0=0.0, y0=0.0):
1944        """
1945        Plot the cotangent fiber structure at a fixed spatial point (x₀[, y₀]).
1946    
1947        This visualization shows how the symbol p(x, ξ) behaves on the cotangent fiber 
1948        above a fixed spatial point. In microlocal analysis, this provides insight into 
1949        the frequency content of the operator at that location.
1950    
1951        Parameters
1952        ----------
1953        x_grid : ndarray
1954            Spatial grid values (1D) for evaluation in 1D case.
1955        xi_grid : ndarray
1956            Frequency grid values (1D) for evaluation in both 1D and 2D cases.
1957        x0 : float, optional
1958            Fixed x-coordinate of the base point in space (1D or 2D).
1959        y0 : float, optional
1960            Fixed y-coordinate of the base point in space (2D only).
1961    
1962        Notes
1963        -----
1964        - In 1D: Displays |p(x, ξ)| over the (x, ξ) phase plane near the fixed point.
1965        - In 2D: Fixes (x₀, y₀) and evaluates p(x₀, y₀, ξ, η), showing the fiber over that point.
1966        - The color map represents the magnitude of the symbol, highlighting regions where it vanishes or becomes singular.
1967    
1968        Raises
1969        ------
1970        NotImplementedError
1971            If called in 2D with missing or improperly formatted grids.
1972        """
1973        if self.dim == 1:
1974            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
1975            symbol_vals = self.p_func(X, XI)
1976            plt.contourf(X, XI, np.abs(symbol_vals), levels=50, cmap='viridis')
1977            plt.colorbar(label='|Symbol|')
1978            plt.xlabel('x (position)')
1979            plt.ylabel('ξ (frequency)')
1980            plt.title('Cotangent Fiber Structure')
1981            plt.show()
1982        elif self.dim == 2:
1983            xi_grid2, eta_grid2 = np.meshgrid(xi_grid, xi_grid)
1984            symbol_vals = self.p_func(x0, y0, xi_grid2, eta_grid2)
1985            plt.contourf(xi_grid, xi_grid, np.abs(symbol_vals), levels=50, cmap='viridis')
1986            plt.colorbar(label='|Symbol|')
1987            plt.xlabel('ξ')
1988            plt.ylabel('η')
1989            plt.title(f'Cotangent Fiber at x={x0}, y={y0}')
1990            plt.show()

Plot the cotangent fiber structure at a fixed spatial point (x₀[, y₀]).

This visualization shows how the symbol p(x, ξ) behaves on the cotangent fiber above a fixed spatial point. In microlocal analysis, this provides insight into the frequency content of the operator at that location.

Parameters

x_grid : ndarray Spatial grid values (1D) for evaluation in 1D case. xi_grid : ndarray Frequency grid values (1D) for evaluation in both 1D and 2D cases. x0 : float, optional Fixed x-coordinate of the base point in space (1D or 2D). y0 : float, optional Fixed y-coordinate of the base point in space (2D only).

Notes

  • In 1D: Displays |p(x, ξ)| over the (x, ξ) phase plane near the fixed point.
  • In 2D: Fixes (x₀, y₀) and evaluates p(x₀, y₀, ξ, η), showing the fiber over that point.
  • The color map represents the magnitude of the symbol, highlighting regions where it vanishes or becomes singular.

Raises

NotImplementedError If called in 2D with missing or improperly formatted grids.

def visualize_symbol_amplitude(self, x_grid, xi_grid, y_grid=None, eta_grid=None, xi0=0.0, eta0=0.0):
1992    def visualize_symbol_amplitude(self, x_grid, xi_grid, y_grid=None, eta_grid=None, xi0=0.0, eta0=0.0):
1993        """
1994        Display the modulus |p(x, ξ)| or |p(x, y, ξ₀, η₀)| as a color map.
1995    
1996        This method visualizes the amplitude of the pseudodifferential operator's symbol 
1997        in either 1D or 2D spatial configuration. In 2D, the frequency variables are fixed 
1998        to specified values (ξ₀, η₀) for visualization purposes.
1999    
2000        Parameters
2001        ----------
2002        x_grid, y_grid : ndarray
2003            Spatial grids over which to evaluate the symbol. y_grid is optional and used only in 2D.
2004        xi_grid, eta_grid : ndarray
2005            Frequency grids. In 2D, these define the domain over which the symbol is evaluated,
2006            but the visualization fixes ξ = ξ₀ and η = η₀.
2007        xi0, eta0 : float, optional
2008            Fixed frequency values for slicing in 2D visualization. Defaults to zero.
2009    
2010        Notes
2011        -----
2012        - In 1D: Visualizes |p(x, ξ)| over the (x, ξ) grid.
2013        - In 2D: Visualizes |p(x, y, ξ₀, η₀)| at fixed frequencies ξ₀ and η₀.
2014        - The color intensity represents the magnitude of the symbol, highlighting regions where the symbol is large or small.
2015        """
2016        if self.dim == 1:
2017            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
2018            symbol_vals = self.p_func(X, XI) 
2019            plt.pcolormesh(X, XI, np.abs(symbol_vals), shading='auto')
2020            plt.colorbar(label='|Symbol|')
2021            plt.xlabel('x')
2022            plt.ylabel('ξ')
2023            plt.title('Symbol Amplitude |p(x, ξ)|')
2024            plt.show()
2025        elif self.dim == 2:
2026            X, Y = np.meshgrid(x_grid, y_grid, indexing='ij')
2027            XI = np.full_like(X, xi0)
2028            ETA = np.full_like(Y, eta0)
2029            symbol_vals = self.p_func(X, Y, XI, ETA)
2030            plt.pcolormesh(X, Y, np.abs(symbol_vals), shading='auto')
2031            plt.colorbar(label='|Symbol|')
2032            plt.xlabel('x')
2033            plt.ylabel('y')
2034            plt.title(f'Symbol Amplitude at ξ={xi0}, η={eta0}')
2035            plt.show()

Display the modulus |p(x, ξ)| or |p(x, y, ξ₀, η₀)| as a color map.

This method visualizes the amplitude of the pseudodifferential operator's symbol in either 1D or 2D spatial configuration. In 2D, the frequency variables are fixed to specified values (ξ₀, η₀) for visualization purposes.

Parameters

x_grid, y_grid : ndarray Spatial grids over which to evaluate the symbol. y_grid is optional and used only in 2D. xi_grid, eta_grid : ndarray Frequency grids. In 2D, these define the domain over which the symbol is evaluated, but the visualization fixes ξ = ξ₀ and η = η₀. xi0, eta0 : float, optional Fixed frequency values for slicing in 2D visualization. Defaults to zero.

Notes

  • In 1D: Visualizes |p(x, ξ)| over the (x, ξ) grid.
  • In 2D: Visualizes |p(x, y, ξ₀, η₀)| at fixed frequencies ξ₀ and η₀.
  • The color intensity represents the magnitude of the symbol, highlighting regions where the symbol is large or small.
def visualize_phase(self, x_grid, xi_grid, y_grid=None, eta_grid=None, xi0=0.0, eta0=0.0):
2037    def visualize_phase(self, x_grid, xi_grid, y_grid=None, eta_grid=None, xi0=0.0, eta0=0.0):
2038        """
2039        Plot the phase (argument) of the pseudodifferential operator's symbol p(x, ξ) or p(x, y, ξ, η).
2040
2041        This visualization helps in understanding the oscillatory behavior and regularity 
2042        properties of the operator in phase space. The phase is displayed modulo 2π using 
2043        a cyclic colormap ('twilight') to emphasize its periodic nature.
2044
2045        Parameters
2046        ----------
2047        x_grid : ndarray
2048            1D array of spatial coordinates (x).
2049        xi_grid : ndarray
2050            1D array of frequency coordinates (ξ).
2051        y_grid : ndarray, optional
2052            2D spatial grid for y-coordinate (in 2D problems). Default is None.
2053        eta_grid : ndarray, optional
2054            2D frequency grid for η (in 2D problems). Not used directly but kept for API consistency.
2055        xi0 : float, optional
2056            Fixed value of ξ for slicing in 2D visualization. Default is 0.0.
2057        eta0 : float, optional
2058            Fixed value of η for slicing in 2D visualization. Default is 0.0.
2059
2060        Notes:
2061        - In 1D: Displays arg(p(x, ξ)) over the (x, ξ) phase plane.
2062        - In 2D: Displays arg(p(x, y, ξ₀, η₀)) for fixed frequency values (ξ₀, η₀).
2063        - Uses plt.pcolormesh with 'twilight' colormap to represent angles from -π to π.
2064
2065        Raises:
2066        - NotImplementedError: If the spatial dimension is not 1D or 2D.
2067        """
2068        if self.dim == 1:
2069            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
2070            symbol_vals = self.p_func(X, XI) 
2071            plt.pcolormesh(X, XI, np.angle(symbol_vals), shading='auto', cmap='twilight')
2072            plt.colorbar(label='arg(Symbol) [rad]')
2073            plt.xlabel('x')
2074            plt.ylabel('ξ')
2075            plt.title('Phase Portrait (arg p(x, ξ))')
2076            plt.show()
2077        elif self.dim == 2:
2078            X, Y = np.meshgrid(x_grid, y_grid, indexing='ij')
2079            XI = np.full_like(X, xi0)
2080            ETA = np.full_like(Y, eta0)
2081            symbol_vals = self.p_func(X, Y, XI, ETA)
2082            plt.pcolormesh(X, Y, np.angle(symbol_vals), shading='auto', cmap='twilight')
2083            plt.colorbar(label='arg(Symbol) [rad]')
2084            plt.xlabel('x')
2085            plt.ylabel('y')
2086            plt.title(f'Phase Portrait at ξ={xi0}, η={eta0}')
2087            plt.show()

Plot the phase (argument) of the pseudodifferential operator's symbol p(x, ξ) or p(x, y, ξ, η).

This visualization helps in understanding the oscillatory behavior and regularity properties of the operator in phase space. The phase is displayed modulo 2π using a cyclic colormap ('twilight') to emphasize its periodic nature.

Parameters

x_grid : ndarray 1D array of spatial coordinates (x). xi_grid : ndarray 1D array of frequency coordinates (ξ). y_grid : ndarray, optional 2D spatial grid for y-coordinate (in 2D problems). Default is None. eta_grid : ndarray, optional 2D frequency grid for η (in 2D problems). Not used directly but kept for API consistency. xi0 : float, optional Fixed value of ξ for slicing in 2D visualization. Default is 0.0. eta0 : float, optional Fixed value of η for slicing in 2D visualization. Default is 0.0.

Notes:

  • In 1D: Displays arg(p(x, ξ)) over the (x, ξ) phase plane.
  • In 2D: Displays arg(p(x, y, ξ₀, η₀)) for fixed frequency values (ξ₀, η₀).
  • Uses plt.pcolormesh with 'twilight' colormap to represent angles from -π to π.

Raises:

  • NotImplementedError: If the spatial dimension is not 1D or 2D.
def visualize_characteristic_set( self, x_grid, xi_grid, y_grid=None, eta_grid=None, y0=0.0, x0=0.0, levels=[0.1]):
2089    def visualize_characteristic_set(self, x_grid, xi_grid, y_grid=None, eta_grid=None, y0=0.0, x0=0.0, levels=[1e-1]):
2090        """
2091        Visualize the characteristic set of the pseudo-differential symbol, defined as the approximate zero set p(x, ξ) ≈ 0.
2092    
2093        In microlocal analysis, the characteristic set is the locus of points in phase space (x, ξ) where the symbol p(x, ξ) vanishes,
2094        playing a key role in understanding propagation of singularities.
2095    
2096        Parameters
2097        ----------
2098        x_grid : ndarray
2099            Spatial grid values (1D array) for plotting in 1D or evaluation point in 2D.
2100        xi_grid : ndarray
2101            Frequency variable grid values (1D array) used to construct the frequency domain.
2102        x0 : float, optional
2103            Fixed spatial coordinate in 2D case for evaluating the symbol at a specific x position.
2104        y0 : float, optional
2105            Fixed spatial coordinate in 2D case for evaluating the symbol at a specific y position.
2106    
2107        Notes
2108        -----
2109        - For 1D, this method plots the contour of |p(x, ξ)| = ε with ε = 1e-5 over the (x, ξ) plane.
2110        - For 2D, it evaluates the symbol at fixed (x₀, y₀) and plots the characteristic set in the (ξ, η) frequency plane.
2111        - This visualization helps identify directions of degeneracy or hypoellipticity of the operator.
2112    
2113        Raises
2114        ------
2115        NotImplementedError
2116            If called on a solver with dimensionality other than 1D or 2D.
2117    
2118        Displays
2119        ------
2120        A matplotlib contour plot showing either:
2121            - The characteristic curve in the (x, ξ) phase plane (1D),
2122            - The characteristic surface slice in the (ξ, η) frequency plane at (x₀, y₀) (2D).
2123        """
2124        if self.dim == 1:
2125            x_grid = np.asarray(x_grid)
2126            xi_grid = np.asarray(xi_grid)
2127            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
2128            symbol_vals = self.p_func(X, XI) 
2129            plt.contour(X, XI, np.abs(symbol_vals), levels=levels, colors='red')
2130            plt.xlabel('x')
2131            plt.ylabel('ξ')
2132            plt.title('Characteristic Set (p(x, ξ) ≈ 0)')
2133            plt.grid(True)
2134            plt.show()
2135        elif self.dim == 2:
2136            if eta_grid is None:
2137                raise ValueError("eta_grid must be provided for 2D visualization.")
2138            xi_grid = np.asarray(xi_grid)
2139            eta_grid = np.asarray(eta_grid)
2140            xi_grid2, eta_grid2 = np.meshgrid(xi_grid, eta_grid, indexing='ij')
2141            symbol_vals = self.p_func(x0, y0, xi_grid2, eta_grid2)
2142            plt.contour(xi_grid, eta_grid, np.abs(symbol_vals), levels=levels, colors='red')
2143            plt.xlabel('ξ')
2144            plt.ylabel('η')
2145            plt.title(f'Characteristic Set at x={x0}, y={y0}')
2146            plt.grid(True)
2147            plt.show()
2148        else:
2149            raise NotImplementedError("Only 1D/2D characteristic sets supported.")

Visualize the characteristic set of the pseudo-differential symbol, defined as the approximate zero set p(x, ξ) ≈ 0.

In microlocal analysis, the characteristic set is the locus of points in phase space (x, ξ) where the symbol p(x, ξ) vanishes, playing a key role in understanding propagation of singularities.

Parameters

x_grid : ndarray Spatial grid values (1D array) for plotting in 1D or evaluation point in 2D. xi_grid : ndarray Frequency variable grid values (1D array) used to construct the frequency domain. x0 : float, optional Fixed spatial coordinate in 2D case for evaluating the symbol at a specific x position. y0 : float, optional Fixed spatial coordinate in 2D case for evaluating the symbol at a specific y position.

Notes

  • For 1D, this method plots the contour of |p(x, ξ)| = ε with ε = 1e-5 over the (x, ξ) plane.
  • For 2D, it evaluates the symbol at fixed (x₀, y₀) and plots the characteristic set in the (ξ, η) frequency plane.
  • This visualization helps identify directions of degeneracy or hypoellipticity of the operator.

Raises

NotImplementedError If called on a solver with dimensionality other than 1D or 2D.

Displays

A matplotlib contour plot showing either: - The characteristic curve in the (x, ξ) phase plane (1D), - The characteristic surface slice in the (ξ, η) frequency plane at (x₀, y₀) (2D).

def visualize_characteristic_gradient(self, x_grid, xi_grid, y_grid=None, eta_grid=None, y0=0.0, x0=0.0):
2151    def visualize_characteristic_gradient(self, x_grid, xi_grid, y_grid=None, eta_grid=None, y0=0.0, x0=0.0):
2152        """
2153        Visualize the norm of the gradient of the symbol in phase space.
2154        
2155        This method computes the magnitude of the gradient |∇p| of a pseudo-differential 
2156        symbol p(x, ξ) in 1D or p(x, y, ξ, η) in 2D. The resulting colormap reveals 
2157        regions where the symbol varies rapidly or remains nearly stationary, 
2158        which is particularly useful for analyzing characteristic sets.
2159        
2160        Parameters
2161        ----------
2162        x_grid : numpy.ndarray
2163            1D array of spatial coordinates for the x-direction.
2164        xi_grid : numpy.ndarray
2165            1D array of frequency coordinates (ξ).
2166        y_grid : numpy.ndarray, optional
2167            1D array of spatial coordinates for the y-direction (used in 2D mode). Default is None.
2168        eta_grid : numpy.ndarray, optional
2169            1D array of frequency coordinates (η) for the 2D case. Default is None.
2170        x0 : float, optional
2171            Fixed x-coordinate for evaluating the symbol in 2D. Default is 0.0.
2172        y0 : float, optional
2173            Fixed y-coordinate for evaluating the symbol in 2D. Default is 0.0.
2174        
2175        Returns
2176        -------
2177        None
2178            Displays a 2D colormap of |∇p| over the relevant phase-space domain.
2179        
2180        Notes
2181        -----
2182        - In 1D, the full gradient ∇p = (∂ₓp, ∂ξp) is computed over the (x, ξ) grid.
2183        - In 2D, the gradient ∇p = (∂ξp, ∂ηp) is computed at a fixed spatial point (x₀, y₀) over the (ξ, η) grid.
2184        - Numerical differentiation is performed using `np.gradient`.
2185        - High values of |∇p| indicate rapid variation of the symbol, while low values typically suggest characteristic regions.
2186        """
2187        if self.dim == 1:
2188            X, XI = np.meshgrid(x_grid, xi_grid, indexing='ij')
2189            symbol_vals = self.p_func(X, XI)
2190            grad_x = np.gradient(symbol_vals, axis=0)
2191            grad_xi = np.gradient(symbol_vals, axis=1)
2192            grad_norm = np.sqrt(grad_x**2 + grad_xi**2)
2193            plt.pcolormesh(X, XI, grad_norm, cmap='inferno', shading='auto')
2194            plt.colorbar(label='|∇p|')
2195            plt.xlabel('x')
2196            plt.ylabel('ξ')
2197            plt.title('Gradient Norm (High Near Zeros)')
2198            plt.grid(True)
2199            plt.show()
2200        elif self.dim == 2:
2201            xi_grid2, eta_grid2 = np.meshgrid(xi_grid, eta_grid, indexing='ij')
2202            symbol_vals = self.p_func(x0, y0, xi_grid2, eta_grid2)
2203            grad_xi = np.gradient(symbol_vals, axis=0)
2204            grad_eta = np.gradient(symbol_vals, axis=1)
2205            grad_norm = np.sqrt(np.abs(grad_xi)**2 + np.abs(grad_eta)**2)
2206            plt.pcolormesh(xi_grid, eta_grid, grad_norm, cmap='inferno', shading='auto')
2207            plt.colorbar(label='|∇p|')
2208            plt.xlabel('ξ')
2209            plt.ylabel('η')
2210            plt.title(f'Gradient Norm at x={x0}, y={y0}')
2211            plt.grid(True)
2212            plt.show()

Visualize the norm of the gradient of the symbol in phase space.

This method computes the magnitude of the gradient |∇p| of a pseudo-differential symbol p(x, ξ) in 1D or p(x, y, ξ, η) in 2D. The resulting colormap reveals regions where the symbol varies rapidly or remains nearly stationary, which is particularly useful for analyzing characteristic sets.

Parameters

x_grid : numpy.ndarray 1D array of spatial coordinates for the x-direction. xi_grid : numpy.ndarray 1D array of frequency coordinates (ξ). y_grid : numpy.ndarray, optional 1D array of spatial coordinates for the y-direction (used in 2D mode). Default is None. eta_grid : numpy.ndarray, optional 1D array of frequency coordinates (η) for the 2D case. Default is None. x0 : float, optional Fixed x-coordinate for evaluating the symbol in 2D. Default is 0.0. y0 : float, optional Fixed y-coordinate for evaluating the symbol in 2D. Default is 0.0.

Returns

None Displays a 2D colormap of |∇p| over the relevant phase-space domain.

Notes

  • In 1D, the full gradient ∇p = (∂ₓp, ∂ξp) is computed over the (x, ξ) grid.
  • In 2D, the gradient ∇p = (∂ξp, ∂ηp) is computed at a fixed spatial point (x₀, y₀) over the (ξ, η) grid.
  • Numerical differentiation is performed using np.gradient.
  • High values of |∇p| indicate rapid variation of the symbol, while low values typically suggest characteristic regions.
def plot_hamiltonian_flow( self, x0=0.0, xi0=5.0, y0=0.0, eta0=0.0, tmax=1.0, n_steps=100, show_field=True):
2214    def plot_hamiltonian_flow(self, x0=0.0, xi0=5.0, y0=0.0, eta0=0.0, tmax=1.0, n_steps=100, show_field=True):
2215        """
2216        Integrate and plot the Hamiltonian trajectories of the symbol in phase space.
2217
2218        This method numerically integrates the Hamiltonian vector field derived from 
2219        the operator's symbol to visualize how singularities propagate under the flow. 
2220        It supports both 1D and 2D problems.
2221
2222        Parameters
2223        ----------
2224        x0, xi0 : float
2225            Initial position and frequency (momentum) in 1D.
2226        y0, eta0 : float, optional
2227            Initial position and frequency in 2D; defaults to zero.
2228        tmax : float
2229            Final integration time for the ODE solver.
2230        n_steps : int
2231            Number of time steps used in the integration.
2232
2233        Notes
2234        -----
2235        - The Hamiltonian vector field is obtained from the symplectic flow of the symbol.
2236        - If the field is complex-valued, only its real part is used for integration.
2237        - In 1D, the trajectory is plotted in (x, ξ) phase space.
2238        - In 2D, the spatial trajectory (x(t), y(t)) is shown along with instantaneous 
2239          momentum vectors (ξ(t), η(t)) using a quiver plot.
2240
2241        Raises
2242        ------
2243        NotImplementedError
2244            If the spatial dimension is not 1D or 2D.
2245
2246        Displays
2247        --------
2248        matplotlib plot
2249            Phase space trajectory(ies) showing the evolution of position and momentum 
2250            under the Hamiltonian dynamics.
2251        """
2252        def make_real(expr):
2253            from sympy import re, simplify
2254            expr = expr.doit(deep=True)
2255            return simplify(re(expr))
2256    
2257        H = self.symplectic_flow()
2258    
2259        if any(im(H[k]) != 0 for k in H):
2260            print("⚠️ The Hamiltonian field is complex. Only the real part is used for integration.")
2261    
2262        if self.dim == 1:
2263            x, = self.vars_x
2264            xi = symbols('xi', real=True)
2265    
2266            dxdt_expr = make_real(H['dx/dt'])
2267            dxidt_expr = make_real(H['dxi/dt'])
2268    
2269            dxdt = lambdify((x, xi), dxdt_expr, 'numpy')
2270            dxidt = lambdify((x, xi), dxidt_expr, 'numpy')
2271    
2272            def hamilton(t, Y):
2273                x, xi = Y
2274                return [dxdt(x, xi), dxidt(x, xi)]
2275    
2276            sol = solve_ivp(hamilton, [0, tmax], [x0, xi0], t_eval=np.linspace(0, tmax, n_steps))
2277
2278            if sol.status != 0:
2279                print(f"⚠️ Integration warning: {sol.message}")
2280            
2281            n_points = sol.y.shape[1]
2282            if n_points < n_steps:
2283                print(f"⚠️ Only {n_points} frames computed. Adjusting animation.")
2284                n_steps = n_points
2285
2286            x_vals, xi_vals = sol.y
2287    
2288            plt.plot(x_vals, xi_vals)
2289            plt.xlabel("x")
2290            plt.ylabel("ξ")
2291            plt.title("Hamiltonian Flow in Phase Space (1D)")
2292            plt.grid(True)
2293            plt.show()
2294    
2295        elif self.dim == 2:
2296            x, y = self.vars_x
2297            xi, eta = symbols('xi eta', real=True)
2298    
2299            dxdt = lambdify((x, y, xi, eta), make_real(H['dx/dt']), 'numpy')
2300            dydt = lambdify((x, y, xi, eta), make_real(H['dy/dt']), 'numpy')
2301            dxidt = lambdify((x, y, xi, eta), make_real(H['dxi/dt']), 'numpy')
2302            detadt = lambdify((x, y, xi, eta), make_real(H['deta/dt']), 'numpy')
2303    
2304            def hamilton(t, Y):
2305                x, y, xi, eta = Y
2306                return [
2307                    dxdt(x, y, xi, eta),
2308                    dydt(x, y, xi, eta),
2309                    dxidt(x, y, xi, eta),
2310                    detadt(x, y, xi, eta)
2311                ]
2312    
2313            sol = solve_ivp(hamilton, [0, tmax], [x0, y0, xi0, eta0], t_eval=np.linspace(0, tmax, n_steps))
2314
2315            if sol.status != 0:
2316                print(f"⚠️ Integration warning: {sol.message}")
2317            
2318            n_points = sol.y.shape[1]
2319            if n_points < n_steps:
2320                print(f"⚠️ Only {n_points} frames computed. Adjusting animation.")
2321                n_steps = n_points
2322
2323            x_vals, y_vals, xi_vals, eta_vals = sol.y
2324    
2325            plt.plot(x_vals, y_vals, label='Position')
2326            plt.quiver(x_vals, y_vals, xi_vals, eta_vals, scale=20, width=0.003, alpha=0.5, color='r')
2327            
2328            # Vector field of the flow (optional)
2329            if show_field:
2330                X, Y = np.meshgrid(np.linspace(min(x_vals), max(x_vals), 20),
2331                                   np.linspace(min(y_vals), max(y_vals), 20))
2332                XI, ETA = xi0 * np.ones_like(X), eta0 * np.ones_like(Y)
2333                U = dxdt(X, Y, XI, ETA)
2334                V = dydt(X, Y, XI, ETA)
2335                plt.quiver(X, Y, U, V, color='gray', alpha=0.2, scale=30, width=0.002)
2336
2337            plt.xlabel("x")
2338            plt.ylabel("y")
2339            plt.title("Hamiltonian Flow in Phase Space (2D)")
2340            plt.legend()
2341            plt.grid(True)
2342            plt.axis('equal')
2343            plt.show()

Integrate and plot the Hamiltonian trajectories of the symbol in phase space.

This method numerically integrates the Hamiltonian vector field derived from the operator's symbol to visualize how singularities propagate under the flow. It supports both 1D and 2D problems.

Parameters

x0, xi0 : float Initial position and frequency (momentum) in 1D. y0, eta0 : float, optional Initial position and frequency in 2D; defaults to zero. tmax : float Final integration time for the ODE solver. n_steps : int Number of time steps used in the integration.

Notes

  • The Hamiltonian vector field is obtained from the symplectic flow of the symbol.
  • If the field is complex-valued, only its real part is used for integration.
  • In 1D, the trajectory is plotted in (x, ξ) phase space.
  • In 2D, the spatial trajectory (x(t), y(t)) is shown along with instantaneous momentum vectors (ξ(t), η(t)) using a quiver plot.

Raises

NotImplementedError If the spatial dimension is not 1D or 2D.

Displays

matplotlib plot Phase space trajectory(ies) showing the evolution of position and momentum under the Hamiltonian dynamics.

def plot_symplectic_vector_field(self, xlim=(-2, 2), klim=(-5, 5), density=30):
2345    def plot_symplectic_vector_field(self, xlim=(-2, 2), klim=(-5, 5), density=30):
2346        """
2347        Visualize the symplectic vector field (Hamiltonian vector field) associated with the operator's symbol.
2348
2349        The plotted vector field corresponds to (∂_ξ p, -∂_x p), where p(x, ξ) is the principal symbol 
2350        of the pseudo-differential operator. This field governs the bicharacteristic flow in phase space.
2351
2352        Parameters
2353        ----------
2354        xlim : tuple of float
2355            Range for spatial variable x, as (x_min, x_max).
2356        klim : tuple of float
2357            Range for frequency variable ξ, as (ξ_min, ξ_max).
2358        density : int
2359            Number of grid points per axis for the visualization grid.
2360
2361        Raises
2362        ------
2363        NotImplementedError
2364            If called on a 2D operator (currently only 1D implementation available).
2365
2366        Notes
2367        -----
2368        - Only supports one-dimensional operators.
2369        - Uses symbolic differentiation to compute ∂_ξ p and ∂_x p.
2370        - Numerical evaluation is done via lambdify with NumPy backend.
2371        - Visualization uses matplotlib quiver plot to show vector directions.
2372        """
2373        x_vals = np.linspace(*xlim, density)
2374        xi_vals = np.linspace(*klim, density)
2375        X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
2376
2377        if self.dim != 1:
2378            raise NotImplementedError("Only 1D version implemented.")
2379
2380        x, = self.vars_x
2381        xi = symbols('xi', real=True)
2382        H = self.symplectic_flow()
2383        dxdt = lambdify((x, xi), simplify(H['dx/dt']), 'numpy')
2384        dxidt = lambdify((x, xi), simplify(H['dxi/dt']), 'numpy')
2385
2386        U = dxdt(X, XI)
2387        V = dxidt(X, XI)
2388
2389        plt.quiver(X, XI, U, V, scale=10, width=0.005)
2390        plt.xlabel('x')
2391        plt.ylabel(r'$\xi$')
2392        plt.title("Symplectic Vector Field (1D)")
2393        plt.grid(True)
2394        plt.show()

Visualize the symplectic vector field (Hamiltonian vector field) associated with the operator's symbol.

The plotted vector field corresponds to (∂_ξ p, -∂_x p), where p(x, ξ) is the principal symbol of the pseudo-differential operator. This field governs the bicharacteristic flow in phase space.

Parameters

xlim : tuple of float Range for spatial variable x, as (x_min, x_max). klim : tuple of float Range for frequency variable ξ, as (ξ_min, ξ_max). density : int Number of grid points per axis for the visualization grid.

Raises

NotImplementedError If called on a 2D operator (currently only 1D implementation available).

Notes

  • Only supports one-dimensional operators.
  • Uses symbolic differentiation to compute ∂_ξ p and ∂_x p.
  • Numerical evaluation is done via lambdify with NumPy backend.
  • Visualization uses matplotlib quiver plot to show vector directions.
def visualize_micro_support(self, xlim=(-2, 2), klim=(-10, 10), threshold=0.001, density=300):
2396    def visualize_micro_support(self, xlim=(-2, 2), klim=(-10, 10), threshold=1e-3, density=300):
2397        """
2398        Visualize the micro-support of the operator by plotting the inverse of the symbol magnitude 1 / |p(x, ξ)|.
2399    
2400        The micro-support provides insight into the singularities of a pseudo-differential operator 
2401        in phase space (x, ξ). Regions where |p(x, ξ)| is small correspond to large values in 1/|p(x, ξ)|,
2402        highlighting areas of significant operator influence or singularity.
2403    
2404        Parameters
2405        ----------
2406        xlim : tuple
2407            Spatial domain limits (x_min, x_max).
2408        klim : tuple
2409            Frequency domain limits (ξ_min, ξ_max).
2410        threshold : float
2411            Threshold below which |p(x, ξ)| is considered effectively zero; used for numerical stability.
2412        density : int
2413            Number of grid points along each axis for visualization resolution.
2414    
2415        Raises
2416        ------
2417        NotImplementedError
2418            If called on a solver with dimension greater than 1 (only 1D visualization is supported).
2419    
2420        Notes
2421        -----
2422        - This method evaluates the symbol p(x, ξ) over a grid and plots its reciprocal to emphasize 
2423          regions where the symbol is near zero.
2424        - A small constant (1e-10) is added to the denominator to avoid division by zero.
2425        - The resulting plot helps identify characteristic sets.
2426        """
2427        if self.dim != 1:
2428            raise NotImplementedError("Only 1D micro-support visualization implemented.")
2429
2430        x_vals = np.linspace(*xlim, density)
2431        xi_vals = np.linspace(*klim, density)
2432        X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
2433        Z = np.abs(self.p_func(X, XI))
2434
2435        plt.contourf(X, XI, 1 / (Z + 1e-10), levels=100, cmap='inferno')
2436        plt.colorbar(label=r'$1/|p(x,\xi)|$')
2437        plt.xlabel('x')
2438        plt.ylabel(r'$\xi$')
2439        plt.title("Micro-Support Estimate (1/|Symbol|)")
2440        plt.show()

Visualize the micro-support of the operator by plotting the inverse of the symbol magnitude 1 / |p(x, ξ)|.

The micro-support provides insight into the singularities of a pseudo-differential operator in phase space (x, ξ). Regions where |p(x, ξ)| is small correspond to large values in 1/|p(x, ξ)|, highlighting areas of significant operator influence or singularity.

Parameters

xlim : tuple Spatial domain limits (x_min, x_max). klim : tuple Frequency domain limits (ξ_min, ξ_max). threshold : float Threshold below which |p(x, ξ)| is considered effectively zero; used for numerical stability. density : int Number of grid points along each axis for visualization resolution.

Raises

NotImplementedError If called on a solver with dimension greater than 1 (only 1D visualization is supported).

Notes

  • This method evaluates the symbol p(x, ξ) over a grid and plots its reciprocal to emphasize regions where the symbol is near zero.
  • A small constant (1e-10) is added to the denominator to avoid division by zero.
  • The resulting plot helps identify characteristic sets.
def group_velocity_field(self, xlim=(-2, 2), klim=(-10, 10), density=30):
2442    def group_velocity_field(self, xlim=(-2, 2), klim=(-10, 10), density=30):
2443        """
2444        Plot the group velocity field ∇_ξ p(x, ξ) for 1D pseudo-differential operators.
2445
2446        The group velocity represents the speed at which waves of different frequencies propagate 
2447        in a dispersive medium. It is defined as the gradient of the symbol p(x, ξ) with respect 
2448        to the frequency variable ξ.
2449
2450        Parameters
2451        ----------
2452        xlim : tuple of float
2453            Spatial domain limits (x-axis).
2454        klim : tuple of float
2455            Frequency domain limits (ξ-axis).
2456        density : int
2457            Number of grid points per axis used for visualization.
2458
2459        Raises
2460        ------
2461        NotImplementedError
2462            If called on a 2D operator, since this visualization is only implemented for 1D.
2463
2464        Notes
2465        -----
2466        - This method visualizes the vector field (∂p/∂ξ) in phase space.
2467        - Used for analyzing wave propagation properties and dispersion relations.
2468        - Requires symbolic expression self.expr depending on x and ξ.
2469        """
2470        if self.dim != 1:
2471            raise NotImplementedError("Only 1D group velocity visualization implemented.")
2472
2473        x, = self.vars_x
2474        xi = symbols('xi', real=True)
2475        dp_dxi = diff(self.symbol, xi)
2476        grad_func = lambdify((x, xi), dp_dxi, 'numpy')
2477
2478        x_vals = np.linspace(*xlim, density)
2479        xi_vals = np.linspace(*klim, density)
2480        X, XI = np.meshgrid(x_vals, xi_vals, indexing='ij')
2481        V = grad_func(X, XI)
2482
2483        plt.quiver(X, XI, np.ones_like(V), V, scale=10, width=0.004)
2484        plt.xlabel('x')
2485        plt.ylabel(r'$\xi$')
2486        plt.title("Group Velocity Field (1D)")
2487        plt.grid(True)
2488        plt.show()

Plot the group velocity field ∇_ξ p(x, ξ) for 1D pseudo-differential operators.

The group velocity represents the speed at which waves of different frequencies propagate in a dispersive medium. It is defined as the gradient of the symbol p(x, ξ) with respect to the frequency variable ξ.

Parameters

xlim : tuple of float Spatial domain limits (x-axis). klim : tuple of float Frequency domain limits (ξ-axis). density : int Number of grid points per axis used for visualization.

Raises

NotImplementedError If called on a 2D operator, since this visualization is only implemented for 1D.

Notes

  • This method visualizes the vector field (∂p/∂ξ) in phase space.
  • Used for analyzing wave propagation properties and dispersion relations.
  • Requires symbolic expression self.expr depending on x and ξ.
def animate_singularity( self, xi0=5.0, eta0=0.0, x0=0.0, y0=0.0, tmax=4.0, n_frames=100, projection=None):
2490    def animate_singularity(self, xi0=5.0, eta0=0.0, x0=0.0, y0=0.0,
2491                            tmax=4.0, n_frames=100, projection=None):
2492        """
2493        Animate the propagation of a singularity under the Hamiltonian flow.
2494
2495        This method visualizes how a singularity (x₀, y₀, ξ₀, η₀) evolves in phase space 
2496        according to the Hamiltonian dynamics induced by the principal symbol of the operator.
2497        The animation integrates the Hamiltonian equations of motion and supports various projections:
2498        position (x-y), frequency (ξ-η), or mixed phase space coordinates.
2499
2500        Parameters
2501        ----------
2502        xi0, eta0 : float
2503            Initial frequency components (ξ₀, η₀).
2504        x0, y0 : float
2505            Initial spatial coordinates (x₀, y₀).
2506        tmax : float
2507            Total time of integration (final animation time).
2508        n_frames : int
2509            Number of frames in the resulting animation.
2510        projection : str or None
2511            Type of projection to display:
2512                - 'position' : x vs y (or x alone in 1D)
2513                - 'frequency': ξ vs η (or ξ alone in 1D)
2514                - 'phase'    : mixed coordinates like x vs ξ or x vs η
2515                If None, defaults to 'phase' in 1D and 'position' in 2D.
2516
2517        Returns
2518        -------
2519        matplotlib.animation.FuncAnimation
2520            Animation object that can be displayed interactively in Jupyter notebooks or saved as a video.
2521
2522        Notes
2523        -----
2524        - In 1D, only one spatial and one frequency variable are used.
2525        - Complex-valued Hamiltonian fields are truncated to their real parts for integration.
2526        - Trajectories are shown with both instantaneous position (dot) and full path (dashed line).
2527        """
2528        rc('animation', html='jshtml')
2529    
2530        def make_real(expr):
2531            from sympy import re, simplify
2532            expr = expr.doit(deep=True)
2533            return simplify(re(expr))
2534  
2535        H = self.symplectic_flow()
2536
2537        H = {k: v.doit(deep=True) for k, v in H.items()}
2538
2539        print("H = ", H)
2540    
2541        if any(im(H[k]) != 0 for k in H):
2542            print("⚠️ The Hamiltonian field is complex. Only the real part is used for integration.")
2543    
2544        if self.dim == 1:
2545            x, = self.vars_x
2546            xi = symbols('xi', real=True)
2547    
2548            dxdt = lambdify((x, xi), make_real(H['dx/dt']), 'numpy')
2549            dxidt = lambdify((x, xi), make_real(H['dxi/dt']), 'numpy')
2550    
2551            def hamilton(t, Y):
2552                x, xi = Y
2553                return [dxdt(x, xi), dxidt(x, xi)]
2554    
2555            sol = solve_ivp(hamilton, [0, tmax], [x0, xi0],
2556                            t_eval=np.linspace(0, tmax, n_frames))
2557            
2558            if sol.status != 0:
2559                print(f"⚠️ Integration warning: {sol.message}")
2560            
2561            n_points = sol.y.shape[1]
2562            if n_points < n_frames:
2563                print(f"⚠️ Only {n_points} frames computed. Adjusting animation.")
2564                n_frames = n_points
2565
2566            x_vals, xi_vals = sol.y
2567    
2568            if projection is None:
2569                projection = 'phase'
2570    
2571            fig, ax = plt.subplots()
2572            point, = ax.plot([], [], 'ro')
2573            traj, = ax.plot([], [], 'b--', lw=1, alpha=0.5)
2574    
2575            if projection == 'phase':
2576                ax.set_xlabel('x')
2577                ax.set_ylabel(r'$\xi$')
2578                ax.set_xlim(np.min(x_vals) - 1, np.max(x_vals) + 1)
2579                ax.set_ylim(np.min(xi_vals) - 1, np.max(xi_vals) + 1)
2580    
2581                def update(i):
2582                    point.set_data([x_vals[i]], [xi_vals[i]])
2583                    traj.set_data(x_vals[:i+1], xi_vals[:i+1])
2584                    return point, traj
2585    
2586            elif projection == 'position':
2587                ax.set_xlabel('x')
2588                ax.set_ylabel('x')
2589                ax.set_xlim(np.min(x_vals) - 1, np.max(x_vals) + 1)
2590                ax.set_ylim(np.min(x_vals) - 1, np.max(x_vals) + 1)
2591    
2592                def update(i):
2593                    point.set_data([x_vals[i]], [x_vals[i]])
2594                    traj.set_data(x_vals[:i+1], x_vals[:i+1])
2595                    return point, traj
2596    
2597            elif projection == 'frequency':
2598                ax.set_xlabel(r'$\xi$')
2599                ax.set_ylabel(r'$\xi$')
2600                ax.set_xlim(np.min(xi_vals) - 1, np.max(xi_vals) + 1)
2601                ax.set_ylim(np.min(xi_vals) - 1, np.max(xi_vals) + 1)
2602    
2603                def update(i):
2604                    point.set_data([xi_vals[i]], [xi_vals[i]])
2605                    traj.set_data(xi_vals[:i+1], xi_vals[:i+1])
2606                    return point, traj
2607    
2608            else:
2609                raise ValueError("Invalid projection mode")
2610    
2611            ax.set_title(f"1D Singularity Flow ({projection})")
2612            ax.grid(True)
2613            ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=50)
2614            plt.close(fig)
2615            return ani
2616    
2617        elif self.dim == 2:
2618            x, y = self.vars_x
2619            xi, eta = symbols('xi eta', real=True)
2620    
2621            dxdt = lambdify((x, y, xi, eta), make_real(H['dx/dt']), 'numpy')
2622            dydt = lambdify((x, y, xi, eta), make_real(H['dy/dt']), 'numpy')
2623            dxidt = lambdify((x, y, xi, eta), make_real(H['dxi/dt']), 'numpy')
2624            detadt = lambdify((x, y, xi, eta), make_real(H['deta/dt']), 'numpy')
2625    
2626            def hamilton(t, Y):
2627                x, y, xi, eta = Y
2628                return [
2629                    dxdt(x, y, xi, eta),
2630                    dydt(x, y, xi, eta),
2631                    dxidt(x, y, xi, eta),
2632                    detadt(x, y, xi, eta)
2633                ]
2634    
2635            sol = solve_ivp(hamilton, [0, tmax], [x0, y0, xi0, eta0],
2636                            t_eval=np.linspace(0, tmax, n_frames))
2637
2638            if sol.status != 0:
2639                print(f"⚠️ Integration warning: {sol.message}")
2640            
2641            n_points = sol.y.shape[1]
2642            if n_points < n_frames:
2643                print(f"⚠️ Only {n_points} frames computed. Adjusting animation.")
2644                n_frames = n_points
2645                
2646            x_vals, y_vals, xi_vals, eta_vals = sol.y
2647    
2648            if projection is None:
2649                projection = 'position'
2650    
2651            fig, ax = plt.subplots()
2652            point, = ax.plot([], [], 'ro')
2653            traj, = ax.plot([], [], 'b--', lw=1, alpha=0.5)
2654    
2655            if projection == 'position':
2656                ax.set_xlabel('x')
2657                ax.set_ylabel('y')
2658                ax.set_xlim(np.min(x_vals) - 1, np.max(x_vals) + 1)
2659                ax.set_ylim(np.min(y_vals) - 1, np.max(y_vals) + 1)
2660    
2661                def update(i):
2662                    point.set_data([x_vals[i]], [y_vals[i]])
2663                    traj.set_data(x_vals[:i+1], y_vals[:i+1])
2664                    return point, traj
2665    
2666            elif projection == 'frequency':
2667                ax.set_xlabel(r'$\xi$')
2668                ax.set_ylabel(r'$\eta$')
2669                ax.set_xlim(np.min(xi_vals) - 1, np.max(xi_vals) + 1)
2670                ax.set_ylim(np.min(eta_vals) - 1, np.max(eta_vals) + 1)
2671    
2672                def update(i):
2673                    point.set_data([xi_vals[i]], [eta_vals[i]])
2674                    traj.set_data(xi_vals[:i+1], eta_vals[:i+1])
2675                    return point, traj
2676    
2677            elif projection == 'phase':
2678                ax.set_xlabel('x')
2679                ax.set_ylabel(r'$\eta$')
2680                ax.set_xlim(np.min(x_vals) - 1, np.max(x_vals) + 1)
2681                ax.set_ylim(np.min(eta_vals) - 1, np.max(eta_vals) + 1)
2682    
2683                def update(i):
2684                    point.set_data([x_vals[i]], [eta_vals[i]])
2685                    traj.set_data(x_vals[:i+1], eta_vals[:i+1])
2686                    return point, traj
2687    
2688            else:
2689                raise ValueError("Invalid projection mode")
2690    
2691            ax.set_title(f"2D Singularity Flow ({projection})")
2692            ax.grid(True)
2693            ax.axis('equal')
2694            ani = animation.FuncAnimation(fig, update, frames=n_frames, interval=50)
2695            plt.close(fig)
2696            return ani

Animate the propagation of a singularity under the Hamiltonian flow.

This method visualizes how a singularity (x₀, y₀, ξ₀, η₀) evolves in phase space according to the Hamiltonian dynamics induced by the principal symbol of the operator. The animation integrates the Hamiltonian equations of motion and supports various projections: position (x-y), frequency (ξ-η), or mixed phase space coordinates.

Parameters

xi0, eta0 : float Initial frequency components (ξ₀, η₀). x0, y0 : float Initial spatial coordinates (x₀, y₀). tmax : float Total time of integration (final animation time). n_frames : int Number of frames in the resulting animation. projection : str or None Type of projection to display: - 'position' : x vs y (or x alone in 1D) - 'frequency': ξ vs η (or ξ alone in 1D) - 'phase' : mixed coordinates like x vs ξ or x vs η If None, defaults to 'phase' in 1D and 'position' in 2D.

Returns

matplotlib.animation.FuncAnimation Animation object that can be displayed interactively in Jupyter notebooks or saved as a video.

Notes

  • In 1D, only one spatial and one frequency variable are used.
  • Complex-valued Hamiltonian fields are truncated to their real parts for integration.
  • Trajectories are shown with both instantaneous position (dot) and full path (dashed line).
def interactive_symbol_analysis( pseudo_op, xlim=(-2, 2), ylim=(-2, 2), xi_range=(0.1, 5), eta_range=(-5, 5), density=100):
2698    def interactive_symbol_analysis(pseudo_op,
2699                                    xlim=(-2, 2), ylim=(-2, 2),
2700                                    xi_range=(0.1, 5), eta_range=(-5, 5),
2701                                    density=100):
2702        """
2703        Launch an interactive dashboard for symbol exploration using ipywidgets.
2704    
2705        This function provides a user-friendly interface to visualize various aspects of the pseudo-differential operator's symbol.
2706        It supports multiple visualization modes in both 1D and 2D, including group velocity fields, micro-support estimates,
2707        symplectic vector fields, symbol amplitude/phase, cotangent fiber structure, characteristic sets and Hamiltonian flows.
2708    
2709        Parameters
2710        ----------
2711        pseudo_op : PseudoDifferentialOperator
2712            The pseudo-differential operator whose symbol is to be analyzed interactively.
2713        xlim, ylim : tuple of float
2714            Spatial domain limits along x and y axes respectively.
2715        xi_range, eta_range : tuple
2716            Frequency domain limits along ξ and η axes respectively.
2717        density : int
2718            Number of points per axis used to construct the evaluation grid. Controls resolution.
2719    
2720        Notes
2721        -----
2722        - In 1D mode, sliders control the fixed frequency (ξ₀) and spatial position (x₀).
2723        - In 2D mode, additional sliders control the second frequency component (η₀) and second spatial coordinate (y₀).
2724        - Visualization updates dynamically as parameters are adjusted via sliders or dropdown menus.
2725        - Supported visualization modes:
2726            'Symbol Amplitude'           : |p(x,ξ)| or |p(x,y,ξ,η)|
2727            'Symbol Phase'               : arg(p(x,ξ)) or similar in 2D
2728            'Micro-Support (1/|p|)'      : Reciprocal of symbol magnitude
2729            'Cotangent Fiber'            : Structure of symbol over frequency space at fixed x
2730            'Characteristic Set'         : Zero set approximation {p ≈ 0}
2731            'Characteristic Gradient'    : |∇p(x, ξ)| or |∇p(x₀, y₀, ξ, η)|
2732            'Group Velocity Field'       : ∇_ξ p(x,ξ) or ∇_{ξ,η} p(x,y,ξ,η)
2733            'Symplectic Vector Field'    : (∇_ξ p, -∇_x p) or similar in 2D
2734            'Hamiltonian Flow'           : Trajectories generated by the Hamiltonian vector field
2735    
2736        Raises
2737        ------
2738        NotImplementedError
2739            If the spatial dimension is not 1D or 2D.
2740    
2741        Prints
2742        ------
2743        Interactive matplotlib figures with dynamic updates based on widget inputs.
2744        """
2745        dim = pseudo_op.dim
2746        expr = pseudo_op.expr
2747        vars_x = pseudo_op.vars_x
2748    
2749        mode_selector_1D = Dropdown(
2750            options=[
2751                'Symbol Amplitude',
2752                'Symbol Phase',
2753                'Micro-Support (1/|p|)',
2754                'Cotangent Fiber',
2755                'Characteristic Set',
2756                'Characteristic Gradient',
2757                'Group Velocity Field',
2758                'Symplectic Vector Field',
2759                'Hamiltonian Flow',
2760            ],
2761            value='Symbol Amplitude',
2762            description='Mode:'
2763        )
2764
2765        mode_selector_2D = Dropdown(
2766            options=[
2767                'Symbol Amplitude',
2768                'Symbol Phase',
2769                'Micro-Support (1/|p|)',
2770                'Cotangent Fiber',
2771                'Characteristic Set',
2772                'Characteristic Gradient',
2773                'Symplectic Vector Field',
2774                'Hamiltonian Flow',
2775            ],
2776            value='Symbol Amplitude',
2777            description='Mode:'
2778        )
2779    
2780        x_vals = np.linspace(*xlim, density)
2781        if dim == 2:
2782            y_vals = np.linspace(*ylim, density)
2783    
2784        if dim == 1:
2785            x, = vars_x
2786            xi = symbols('xi', real=True)
2787            grad_func = lambdify((x, xi), diff(expr, xi), 'numpy')
2788            symplectic_func = lambdify((x, xi), [diff(expr, xi), -diff(expr, x)], 'numpy')
2789            symbol_func = lambdify((x, xi), expr, 'numpy')
2790
2791            xi_slider = FloatSlider(min=xi_range[0], max=xi_range[1], step=0.1, value=1.0, description='ξ₀')
2792            x_slider = FloatSlider(min=xlim[0], max=xlim[1], step=0.1, value=0.0, description='x₀')
2793    
2794            def plot_1d(mode, xi0, x0):
2795                X = x_vals[:, None]
2796    
2797                if mode == 'Group Velocity Field':
2798                    V = grad_func(X, xi0)
2799                    plt.quiver(X, V, np.ones_like(V), V, scale=10, width=0.004)
2800                    plt.xlabel('x')
2801                    plt.title(f'Group Velocity Field at ξ={xi0:.2f}')
2802    
2803                elif mode == 'Micro-Support (1/|p|)':
2804                    Z = 1 / (np.abs(symbol_func(X, xi0)) + 1e-10)
2805                    plt.plot(x_vals, Z)
2806                    plt.xlabel('x')
2807                    plt.title(f'Micro-Support (1/|p|) at ξ={xi0:.2f}')
2808    
2809                elif mode == 'Symplectic Vector Field':
2810                    U, V = symplectic_func(X, xi0)
2811                    plt.quiver(X, V, U, V, scale=10, width=0.004)
2812                    plt.xlabel('x')
2813                    plt.title(f'Symplectic Field at ξ={xi0:.2f}')
2814    
2815                elif mode == 'Symbol Amplitude':
2816                    Z = np.abs(symbol_func(X, xi0))
2817                    plt.plot(x_vals, Z)
2818                    plt.xlabel('x')
2819                    plt.title(f'Symbol Amplitude |p(x,ξ)| at ξ={xi0:.2f}')
2820    
2821                elif mode == 'Symbol Phase':
2822                    Z = np.angle(symbol_func(X, xi0))
2823                    plt.plot(x_vals, Z)
2824                    plt.xlabel('x')
2825                    plt.title(f'Symbol Phase arg(p(x,ξ)) at ξ={xi0:.2f}')
2826    
2827                elif mode == 'Cotangent Fiber':
2828                    pseudo_op.visualize_fiber(x_vals, np.linspace(*xi_range, density), x0=x0)
2829    
2830                elif mode == 'Characteristic Set':
2831                    pseudo_op.visualize_characteristic_set(x_vals, np.linspace(*xi_range, density), x0=x0)
2832    
2833                elif mode == 'Characteristic Gradient':
2834                    pseudo_op.visualize_characteristic_gradient(x_vals, np.linspace(*xi_range, density), x0=x0)
2835    
2836                elif mode == 'Hamiltonian Flow':
2837                    pseudo_op.plot_hamiltonian_flow(x0=x0, xi0=xi0)
2838    
2839            # --- Dynamic container for sliders ---
2840            controls_box = VBox([mode_selector_1D, xi_slider, x_slider])
2841            # --- Function to adjust visible sliders based on mode ---
2842            def update_controls(change):
2843                mode = change['new']
2844                # modes that depend only on xi and eta
2845                if mode in ['Symbol Amplitude', 'Symbol Phase', 'Micro-Support (1/|p|)',
2846                            'Group Velocity Field', 'Symplectic Vector Field']:
2847                    controls_box.children = [mode_selector_1D, xi_slider]
2848                # modes that require xi and x
2849                elif mode in ['Hamiltonian Flow']:
2850                    controls_box.children = [mode_selector_1D, xi_slider, x_slider]
2851                # modes that require nothing
2852                elif mode in ['Cotangent Fiber', 'Characteristic Set', 'Characteristic Gradient']:
2853                    controls_box.children = [mode_selector_1D]
2854            mode_selector_1D.observe(update_controls, names='value')
2855            update_controls({'new': mode_selector_1D.value}) 
2856            # --- Interactive binding ---
2857            out = interactive_output(plot_1d, {'mode': mode_selector_1D, 'xi0': xi_slider, 'x0': x_slider})
2858            display(VBox([controls_box, out]))
2859
2860        elif dim == 2:
2861            x, y = vars_x
2862            xi, eta = symbols('xi eta', real=True)
2863            symplectic_func = lambdify((x, y, xi, eta), [diff(expr, xi), diff(expr, eta)], 'numpy')
2864            symbol_func = lambdify((x, y, xi, eta), expr, 'numpy')
2865
2866            xi_slider=FloatSlider(min=xi_range[0], max=xi_range[1], step=0.1, value=1.0, description='ξ₀')
2867            eta_slider=FloatSlider(min=eta_range[0], max=eta_range[1], step=0.1, value=1.0, description='η₀')
2868            x_slider=FloatSlider(min=xlim[0], max=xlim[1], step=0.1, value=0.0, description='x₀')
2869            y_slider=FloatSlider(min=ylim[0], max=ylim[1], step=0.1, value=0.0, description='y₀')
2870    
2871            def plot_2d(mode, xi0, eta0, x0, y0):
2872                X, Y = np.meshgrid(x_vals, y_vals, indexing='ij')
2873    
2874                if mode == 'Micro-Support (1/|p|)':
2875                    Z = 1 / (np.abs(symbol_func(X, Y, xi0, eta0)) + 1e-10)
2876                    plt.pcolormesh(X, Y, Z, shading='auto', cmap='inferno')
2877                    plt.colorbar(label='1/|p|')
2878                    plt.xlabel('x')
2879                    plt.ylabel('y')
2880                    plt.title(f'Micro-Support at ξ={xi0:.2f}, η={eta0:.2f}')
2881    
2882                elif mode == 'Symplectic Vector Field':
2883                    U, V = symplectic_func(X, Y, xi0, eta0)
2884                    plt.quiver(X, Y, U, V, scale=10, width=0.004)
2885                    plt.xlabel('x')
2886                    plt.ylabel('y')
2887                    plt.title(f'Symplectic Field at ξ={xi0:.2f}, η={eta0:.2f}')
2888    
2889                elif mode == 'Symbol Amplitude':
2890                    Z = np.abs(symbol_func(X, Y, xi0, eta0))
2891                    plt.pcolormesh(X, Y, Z, shading='auto')
2892                    plt.colorbar(label='|p(x,y,ξ,η)|')
2893                    plt.xlabel('x')
2894                    plt.ylabel('y')
2895                    plt.title(f'Symbol Amplitude at ξ={xi0:.2f}, η={eta0:.2f}')
2896    
2897                elif mode == 'Symbol Phase':
2898                    Z = np.angle(symbol_func(X, Y, xi0, eta0))
2899                    plt.pcolormesh(X, Y, Z, shading='auto', cmap='twilight')
2900                    plt.colorbar(label='arg(p)')
2901                    plt.xlabel('x')
2902                    plt.ylabel('y')
2903                    plt.title(f'Symbol Phase at ξ={xi0:.2f}, η={eta0:.2f}')
2904    
2905                elif mode == 'Cotangent Fiber':
2906                    pseudo_op.visualize_fiber(np.linspace(*xi_range, density), np.linspace(*eta_range, density),
2907                                              x0=x0, y0=y0)
2908    
2909                elif mode == 'Characteristic Set':
2910                    pseudo_op.visualize_characteristic_set(x_grid=x_vals, xi_grid=np.linspace(*xi_range, density),
2911                                                  y_grid=y_vals, eta_grid=np.linspace(*eta_range, density), x0=x0, y0=y0)
2912    
2913                elif mode == 'Characteristic Gradient':
2914                    pseudo_op.visualize_characteristic_gradient(x_grid=x_vals, xi_grid=np.linspace(*xi_range, density),
2915                                                  y_grid=y_vals, eta_grid=np.linspace(*eta_range, density), x0=x0, y0=y0)
2916    
2917                elif mode == 'Hamiltonian Flow':
2918                    pseudo_op.plot_hamiltonian_flow(x0=x0, y0=y0, xi0=xi0, eta0=eta0)
2919                    
2920            # --- Dynamic container for sliders ---
2921            controls_box = VBox([mode_selector_2D, xi_slider, eta_slider, x_slider, y_slider])
2922            # --- Function to adjust visible sliders based on mode ---
2923            def update_controls(change):
2924                mode = change['new']
2925                # modes that depend only on xi
2926                if mode in ['Symbol Amplitude', 'Symbol Phase', 'Micro-Support (1/|p|)', 'Symplectic Vector Field']:
2927                    controls_box.children = [mode_selector_2D, xi_slider, eta_slider]
2928                # modes that require xi, eta, x and y
2929                elif mode in ['Hamiltonian Flow']:
2930                    controls_box.children = [mode_selector_2D, xi_slider, eta_slider, x_slider, y_slider]
2931                # modes that require x and y
2932                elif mode in ['Cotangent Fiber', 'Characteristic Set', 'Characteristic Gradient']:
2933                    controls_box.children = [mode_selector_2D, x_slider, y_slider]
2934            mode_selector_2D.observe(update_controls, names='value')
2935            update_controls({'new': mode_selector_2D.value}) 
2936            # --- Interactive binding ---
2937            out = interactive_output(plot_2d, {'mode': mode_selector_2D, 'xi0': xi_slider, 'eta0': eta_slider, 'x0': x_slider, 'y0': y_slider})
2938            display(VBox([controls_box, out]))

Launch an interactive dashboard for symbol exploration using ipywidgets.

This function provides a user-friendly interface to visualize various aspects of the pseudo-differential operator's symbol. It supports multiple visualization modes in both 1D and 2D, including group velocity fields, micro-support estimates, symplectic vector fields, symbol amplitude/phase, cotangent fiber structure, characteristic sets and Hamiltonian flows.

Parameters

pseudo_op : PseudoDifferentialOperator The pseudo-differential operator whose symbol is to be analyzed interactively. xlim, ylim : tuple of float Spatial domain limits along x and y axes respectively. xi_range, eta_range : tuple Frequency domain limits along ξ and η axes respectively. density : int Number of points per axis used to construct the evaluation grid. Controls resolution.

Notes

  • In 1D mode, sliders control the fixed frequency (ξ₀) and spatial position (x₀).
  • In 2D mode, additional sliders control the second frequency component (η₀) and second spatial coordinate (y₀).
  • Visualization updates dynamically as parameters are adjusted via sliders or dropdown menus.
  • Supported visualization modes: 'Symbol Amplitude' : |p(x,ξ)| or |p(x,y,ξ,η)| 'Symbol Phase' : arg(p(x,ξ)) or similar in 2D 'Micro-Support (1/|p|)' : Reciprocal of symbol magnitude 'Cotangent Fiber' : Structure of symbol over frequency space at fixed x 'Characteristic Set' : Zero set approximation {p ≈ 0} 'Characteristic Gradient' : |∇p(x, ξ)| or |∇p(x₀, y₀, ξ, η)| 'Group Velocity Field' : ∇_ξ p(x,ξ) or ∇_{ξ,η} p(x,y,ξ,η) 'Symplectic Vector Field' : (∇_ξ p, -∇_x p) or similar in 2D 'Hamiltonian Flow' : Trajectories generated by the Hamiltonian vector field

Raises

NotImplementedError If the spatial dimension is not 1D or 2D.

Prints

Interactive matplotlib figures with dynamic updates based on widget inputs.

class PDESolver:
  28class PDESolver:
  29    """
  30    A partial differential equation (PDE) solver based on **spectral methods** using Fourier transforms.
  31
  32    This solver supports symbolic specification of PDEs via SymPy and numerical solution using high-order spectral techniques. 
  33    It is designed for both **linear and nonlinear time-dependent PDEs**, as well as **stationary pseudo-differential problems**.
  34    
  35    Key Features:
  36    -------------
  37    - Symbolic PDE parsing using SymPy expressions
  38    - 1D and 2D spatial domains with periodic boundary conditions
  39    - Fourier-based spectral discretization with dealiasing
  40    - Temporal integration schemes:
  41        - Default exponential time stepping
  42        - ETD-RK4 (Exponential Time Differencing Runge-Kutta of 4th order)
  43    - Nonlinear terms handled through pseudo-spectral evaluation
  44    - Built-in tools for:
  45        - Visualization of solutions and error surfaces
  46        - Symbol analysis of linear and pseudo-differential operators
  47        - Microlocal analysis (e.g., Hamiltonian flows)
  48        - CFL condition checking and numerical stability diagnostics
  49
  50    Supported Operators:
  51    --------------------
  52    - Linear differential and pseudo-differential operators
  53    - Nonlinear terms up to second order in derivatives
  54    - Symbolic operator composition and adjoints
  55    - Asymptotic inversion of elliptic operators for stationary problems
  56
  57    Example Usage:
  58    --------------
  59    >>> from PDESolver import *
  60    >>> u = Function('u')
  61    >>> t, x = symbols('t x')
  62    >>> eq = Eq(diff(u(t, x), t), diff(u(t, x), x, 2) + u(t, x)**2)
  63    >>> def _initial(x): return np.sin(x)
  64    >>> solver = PDESolver(eq)
  65    >>> solver.setup(Lx=2*np.pi, Nx=128, Lt=1.0, Nt=1000, initial_condition=initial)
  66    >>> solver.solve()
  67    >>> ani = solver.animate()
  68    >>> HTML(ani.to_jshtml())  # Display animation in Jupyter notebook
  69    """
  70    def __init__(self, equation, time_scheme='default', dealiasing_ratio=2/3):
  71        """
  72        Initialize the PDE solver with a given equation.
  73
  74        This method analyzes the input partial differential equation (PDE), 
  75        identifies the unknown function and its dependencies, determines whether 
  76        the problem is stationary or time-dependent, and prepares symbolic and 
  77        numerical structures for solving in spectral space.
  78
  79        Supported features:
  80        
  81        - 1D and 2D problems
  82        - Time-dependent and stationary equations
  83        - Linear and nonlinear terms
  84        - Pseudo-differential operators via `psiOp`
  85        - Source terms and boundary conditions
  86
  87        The equation is parsed to extract linear, nonlinear, source, and 
  88        pseudo-differential components. Symbolic manipulation is used to derive 
  89        the Fourier representation of linear operators when applicable.
  90
  91        Parameters
  92        ----------
  93        equation : sympy.Eq 
  94            The PDE expressed as a SymPy equation.
  95        time_scheme : str
  96            Temporal integration scheme: 
  97                - 'default' for exponential 
  98                - time-stepping or 'ETD-RK4' for fourth-order exponential 
  99                - time differencing Runge–Kutta.
 100        dealiasing_ratio : float
 101            Fraction of high-frequency modes to zero out 
 102            during dealiasing (e.g., 2/3 for standard truncation).
 103
 104        Attributes initialized:
 105        
 106        - self.u: the unknown function (e.g., u(t, x))
 107        - self.dim: spatial dimension (1 or 2)
 108        - self.spatial_vars: list of spatial variables (e.g., [x] or [x, y])
 109        - self.is_stationary: boolean indicating if the problem is stationary
 110        - self.linear_terms: dictionary mapping derivative orders to coefficients
 111        - self.nonlinear_terms: list of nonlinear expressions
 112        - self.source_terms: list of source functions
 113        - self.pseudo_terms: list of pseudo-differential operator expressions
 114        - self.has_psi: boolean indicating presence of pseudo-differential operators
 115        - self.fft / self.ifft: appropriate FFT routines based on spatial dimension
 116        - self.kx, self.ky: symbolic wavenumber variables for Fourier space
 117
 118        Raises:
 119            ValueError: If the equation does not contain exactly one unknown function,
 120                        if unsupported dimensions are detected, or invalid dependencies.
 121        """
 122        self.time_scheme = time_scheme # 'default'  or 'ETD-RK4'
 123        self.dealiasing_ratio = dealiasing_ratio
 124        
 125        print("\n*********************************")
 126        print("* Partial differential equation *")
 127        print("*********************************\n")
 128        pprint(equation, num_columns=NUM_COLS)
 129        
 130        # Extract symbols and function from the equation
 131        functions = equation.atoms(Function)
 132        
 133        # Ignore the wrappers psiOp and Op
 134        excluded_wrappers = {'psiOp', 'Op'}
 135        
 136        # Extract the candidate fonctions (excluding wrappers)
 137        candidate_functions = [
 138            f for f in functions 
 139            if f.func.__name__ not in excluded_wrappers
 140        ]
 141        
 142        # Keep only user functions (u(x), u(x, t), etc.)
 143        candidate_functions = [
 144            f for f in functions
 145            if isinstance(f, AppliedUndef)
 146        ]
 147        
 148        # Stationary detection: no dependence on t
 149        self.is_stationary = all(
 150            not any(str(arg) == 't' for arg in f.args)
 151            for f in candidate_functions
 152        )
 153        
 154        if len(candidate_functions) != 1:
 155            print("candidate_functions :", candidate_functions)
 156            raise ValueError("The equation must contain exactly one unknown function")
 157        
 158        self.u = candidate_functions[0]
 159
 160        self.u_eq = self.u
 161
 162        args = self.u.args
 163        
 164        if self.is_stationary:
 165            if len(args) not in (1, 2):
 166                raise ValueError("Stationary problems must depend on 1 or 2 spatial variables")
 167            self.spatial_vars = args
 168        else:
 169            if len(args) < 2 or len(args) > 3:
 170                raise ValueError("The function must depend on t and at least one spatial variable (x [, y])")
 171            self.t = args[0]
 172            self.spatial_vars = args[1:]
 173
 174        self.dim = len(self.spatial_vars)
 175        if self.dim == 1:
 176            self.x = self.spatial_vars[0]
 177            self.y = None
 178        elif self.dim == 2:
 179            self.x, self.y = self.spatial_vars
 180        else:
 181            raise ValueError("Only 1D and 2D problems are supported.")
 182
 183        if self.dim == 1:
 184            self.fft = partial(fft, workers=FFT_WORKERS)
 185            self.ifft = partial(ifft, workers=FFT_WORKERS)
 186        else:
 187            self.fft = partial(fft2, workers=FFT_WORKERS)
 188            self.ifft = partial(ifft2, workers=FFT_WORKERS)
 189            
 190        # Parse the equation
 191        self.linear_terms = {}
 192        self.nonlinear_terms = []
 193        self.symbol_terms = []
 194        self.source_terms = []
 195        self.pseudo_terms = []
 196        self.temporal_order = 0  # Order of the temporal derivative
 197        self.linear_terms, self.nonlinear_terms, self.symbol_terms, self.source_terms, self.pseudo_terms = self._parse_equation(equation)
 198        # flag : pseudo‑differential operator present ?
 199        self.has_psi = bool(self.pseudo_terms)
 200        if self.has_psi:
 201            print('⚠️  Pseudo‑differential operator detected: all other linear terms have been rejected.')
 202            self.is_spatial = False
 203            for coeff, expr in self.pseudo_terms:
 204                if expr.has(self.x) or (self.dim == 2 and expr.has(self.y)):
 205                    self.is_spatial = True
 206                    break
 207    
 208        if self.dim == 1:
 209            self.kx = symbols('kx')
 210        elif self.dim == 2:
 211            self.kx, self.ky = symbols('kx ky')
 212    
 213        # Compute linear operator
 214        if not self.is_stationary:
 215            self._compute_linear_operator()
 216        else:
 217            self.psi_ops = []
 218            for coeff, sym_expr in self.pseudo_terms:
 219                psi = PseudoDifferentialOperator(sym_expr, self.spatial_vars, self.u, mode='symbol')
 220                self.psi_ops.append((coeff, psi))
 221
 222    def _parse_equation(self, equation):
 223        """
 224        Parse the PDE to separate linear and nonlinear terms, symbolic operators (Op), 
 225        source terms, and pseudo-differential operators (psiOp).
 226    
 227        This method rewrites the input equation in standard form (lhs - rhs = 0),
 228        expands it, and classifies each term into one of the following categories:
 229        
 230        - Linear terms involving derivatives or the unknown function u
 231        - Nonlinear terms (products with u, powers of u, etc.)
 232        - Symbolic pseudo-differential operators (Op)
 233        - Source terms (independent of u)
 234        - Pseudo-differential operators (psiOp)
 235    
 236        Parameters
 237            equation (sympy.Eq): The partial differential equation to be analyzed. 
 238                                 Can be provided as an Eq object or a sympy expression.
 239    
 240        Returns:
 241            tuple: A 5-tuple containing:
 242            
 243                - linear_terms (dict): Mapping from derivative/function to coefficient.
 244                - nonlinear_terms (list): List of terms classified as nonlinear.
 245                - symbol_terms (list): List of (coefficient, symbolic operator) pairs.
 246                - source_terms (list): List of terms independent of the unknown function.
 247                - pseudo_terms (list): List of (coefficient, pseudo-differential symbol) pairs.
 248    
 249        Notes:
 250            - If `psiOp` is present in the equation, expansion is skipped for safety.
 251            - When `psiOp` is used, only nonlinear terms, source terms, and possibly 
 252              a time derivative are allowed; other linear terms and symbolic operators 
 253              (Op) are forbidden.
 254            - Classification logic includes:
 255                - Detection of nonlinear structures like products or powers of u
 256                - Mixed terms involving both u and its derivatives
 257                - External symbolic operators (Op) and pseudo-differential operators (psiOp)
 258        """
 259        def _is_nonlinear_term(term, u_func):
 260            # If the term contains functions (Abs, sin, exp, ...) applied to u
 261            if term.has(u_func):
 262                for sub in preorder_traversal(term):
 263                    if isinstance(sub, Function) and sub.has(u_func) and sub.func != u_func.func:
 264                        return True
 265            # If the term contains a nonlinear power of u
 266            if term.has(Pow):
 267                for pow_term in term.atoms(Pow):
 268                    if pow_term.base == u_func and pow_term.exp != 1:
 269                        return True
 270            # If the term is a product containing u and its derivative
 271            if term.func == Mul:
 272                factors = term.args
 273                has_u = any((f.has(u_func) and not isinstance(f, Derivative) for f in factors))
 274                has_derivative = any((isinstance(f, Derivative) and f.expr.func == u_func.func for f in factors))
 275                if has_u and has_derivative:
 276                    return True
 277            return False
 278    
 279        print("\n********************")
 280        print("* Equation parsing *")
 281        print("********************\n")
 282    
 283        if isinstance(equation, Eq):
 284            lhs = equation.lhs - equation.rhs
 285        else:
 286            lhs = equation
 287    
 288        print(f"\nEquation rewritten in standard form: {lhs}")
 289        if lhs.has(psiOp):
 290            print("⚠️ psiOp detected: skipping expansion for safety")
 291            lhs_expanded = lhs
 292        else:
 293            lhs_expanded = expand(lhs)
 294    
 295        print(f"\nExpanded equation: {lhs_expanded}")
 296    
 297        linear_terms = {}
 298        nonlinear_terms = []
 299        symbol_terms = []
 300        source_terms = []
 301        pseudo_terms = []
 302    
 303        for term in lhs_expanded.as_ordered_terms():
 304            print(f"Analyzing term: {term}")
 305    
 306            if isinstance(term, psiOp):
 307                expr = term.args[0]
 308                pseudo_terms.append((1, expr))
 309                print("  --> Classified as pseudo linear term (psiOp)")
 310                continue
 311    
 312            # Otherwise, look for psiOp inside (general case)
 313            if term.has(psiOp):
 314                psiops = term.atoms(psiOp)
 315                for psi in psiops:
 316                    try:
 317                        coeff = simplify(term / psi)
 318                        expr = psi.args[0]
 319                        pseudo_terms.append((coeff, expr))
 320                        print("  --> Classified as pseudo linear term (psiOp)")
 321                    except Exception as e:
 322                        print(f"  ⚠️ Failed to extract psiOp coefficient in term: {term}")
 323                        print(f"     Reason: {e}")
 324                        nonlinear_terms.append(term)
 325                        print("  --> Fallback: classified as nonlinear")
 326                continue
 327    
 328            if term.has(Op):
 329                ops = term.atoms(Op)
 330                for op in ops:
 331                    coeff = term / op
 332                    expr = op.args[0]
 333                    symbol_terms.append((coeff, expr))
 334                    print("  --> Classified as symbolic linear term (Op)")
 335                continue
 336    
 337            if _is_nonlinear_term(term, self.u):
 338                nonlinear_terms.append(term)
 339                print("  --> Classified as nonlinear")
 340                continue
 341    
 342            derivs = term.atoms(Derivative)
 343            if derivs:
 344                deriv = derivs.pop()
 345                coeff = term / deriv
 346                linear_terms[deriv] = linear_terms.get(deriv, 0) + coeff
 347                print(f"  Derivative found: {deriv}")
 348                print("  --> Classified as linear")
 349            elif self.u in term.atoms(Function):
 350                coeff = term.as_coefficients_dict().get(self.u, 1)
 351                linear_terms[self.u] = linear_terms.get(self.u, 0) + coeff
 352                print("  --> Classified as linear")
 353            else:
 354                source_terms.append(term)
 355                print("  --> Classified as source term")
 356    
 357        print(f"Final linear terms: {linear_terms}")
 358        print(f"Final nonlinear terms: {nonlinear_terms}")
 359        print(f"Symbol terms: {symbol_terms}")
 360        print(f"Pseudo terms: {pseudo_terms}")
 361        print(f"Source terms: {source_terms}")
 362    
 363        if pseudo_terms:
 364            # Check if a time derivative is present among the linear terms
 365            has_time_derivative = any(
 366                isinstance(term, Derivative) and self.t in [v for v, _  in term.variable_count]
 367                for term in linear_terms
 368            )
 369            # Extract non-temporal linear terms
 370            invalid_linear_terms = {
 371                term: coeff for term, coeff in linear_terms.items()
 372                if not (
 373                    isinstance(term, Derivative)
 374                    and self.t in [v for v, _  in term.variable_count]
 375                )
 376                and term != self.u  # exclusion of the simple u term (without derivative)
 377            }
 378    
 379            if invalid_linear_terms or symbol_terms:
 380                raise ValueError(
 381                    "When psiOp is used, only nonlinear terms, source terms, "
 382                    "and possibly a time derivative are allowed. "
 383                    "Other linear terms and Ops are forbidden."
 384                )
 385    
 386        return linear_terms, nonlinear_terms, symbol_terms, source_terms, pseudo_terms
 387
 388
 389    def _compute_linear_operator(self):
 390        """
 391        Compute the symbolic Fourier representation L(k) of the linear operator 
 392        derived from the linear part of the PDE.
 393    
 394        This method constructs a dispersion relation by applying each symbolic derivative
 395        to a plane wave exp(i(k·x - ωt)) and extracting the resulting expression.
 396        It handles arbitrary derivative combinations and includes symbolic and
 397        pseudo-differential terms.
 398    
 399        Steps:
 400        -------
 401        1. Construct a plane wave φ(x, t) = exp(i(k·x - ωt)).
 402        2. Apply each term from self.linear_terms to φ.
 403        3. Normalize by φ and simplify to obtain L(k).
 404        4. Include symbolic terms (e.g., psiOp) if present.
 405        5. Detect the temporal order from the dispersion relation.
 406        6. Build the numerical function L(k) via lambdify.
 407    
 408        Sets:
 409        -----
 410        - self.L_symbolic : sympy.Expr
 411            Symbolic form of L(k).
 412        - self.L : callable
 413            Numerical function of L(kx[, ky]).
 414        - self.omega : callable or None
 415            Frequency root ω(k), if available.
 416        - self.temporal_order : int
 417            Order of time derivatives detected.
 418        - self.psi_ops : list of (coeff, PseudoDifferentialOperator)
 419            Pseudo-differential terms present in the equation.
 420    
 421        Raises:
 422        -------
 423        ValueError if the dimension is unsupported or the dispersion relation fails.
 424        """
 425        print("\n*******************************")
 426        print("* Linear operator computation *")
 427        print("*******************************\n")
 428    
 429        # --- Step 1: symbolic variables ---
 430        omega = symbols("omega")
 431        if self.dim == 1:
 432            kvars = [symbols("kx")]
 433            space_vars = [self.x]
 434        elif self.dim == 2:
 435            kvars = symbols("kx ky")
 436            space_vars = [self.x, self.y]
 437        else:
 438            raise ValueError("Only 1D and 2D are supported.")
 439    
 440        kdict = dict(zip(space_vars, kvars))
 441        self.k_symbols = kvars
 442    
 443        # Plane wave expression
 444        phase = sum(k * x for k, x in zip(kvars, space_vars)) - omega * self.t
 445        plane_wave = exp(I * phase)
 446    
 447        # --- Step 2: build lhs expression from linear terms ---
 448        lhs = 0
 449        for deriv, coeff in self.linear_terms.items():
 450            if isinstance(deriv, Derivative):
 451                total_factor = 1
 452                for var, n in deriv.variable_count:
 453                    if var == self.t:
 454                        total_factor *= (-I * omega)**n
 455                    elif var in kdict:
 456                        total_factor *= (I * kdict[var])**n
 457                    else:
 458                        raise ValueError(f"Unknown variable {var} in derivative")
 459                lhs += coeff * total_factor * plane_wave
 460            elif deriv == self.u:
 461                lhs += coeff * plane_wave
 462            else:
 463                raise ValueError(f"Unsupported linear term: {deriv}")
 464    
 465        # --- Step 3: dispersion relation ---
 466        equation = simplify(lhs / plane_wave)
 467        print("\nCharacteristic equation before symbol treatment:")
 468        pprint(equation, num_columns=NUM_COLS)
 469
 470        print("\n--- Symbolic symbol analysis ---")
 471        symb_omega = 0
 472        symb_k = 0
 473        
 474        for coeff, symbol in self.symbol_terms:
 475            if symbol.has(omega):
 476                # Ajouter directement les termes dépendant de omega
 477                symb_omega += coeff * symbol
 478            elif any(symbol.has(k) for k in self.k_symbols):
 479                 symb_k += coeff * symbol.subs(dict(zip(symbol.free_symbols, self.k_symbols)))
 480
 481        print(f"symb_omega: {symb_omega}")
 482        print(f"symb_k: {symb_k}")
 483        
 484        equation = equation + symb_omega + symb_k         
 485
 486        print("\nRaw characteristic equation:")
 487        pprint(equation, num_columns=NUM_COLS)
 488
 489        # Temporal derivative order detection
 490        try:
 491            poly_eq = Eq(equation, 0)
 492            poly = poly_eq.lhs.as_poly(omega)
 493            self.temporal_order = poly.degree() if poly else 0
 494        except Exception as e:
 495            warnings.warn(f"Could not determine temporal order: {e}", RuntimeWarning)
 496            self.temporal_order = 0
 497        print(f"Temporal order from dispersion relation: {self.temporal_order}")
 498        print('self.pseudo_terms = ', self.pseudo_terms)
 499        if self.pseudo_terms:
 500            coeff_time = 1
 501            for term, coeff in self.linear_terms.items():
 502                if isinstance(term, Derivative) and any(var == self.t for var, _  in term.variable_count):
 503                    coeff_time = coeff
 504                    print(f"✅ Time derivative coefficient detected: {coeff_time}")
 505            self.psi_ops = []
 506            for coeff, sym_expr in self.pseudo_terms:
 507                # expr est le Sympy expr. différentiel, var_x la liste [x] ou [x,y]
 508                psi = PseudoDifferentialOperator(sym_expr / coeff_time, self.spatial_vars, self.u, mode='symbol')
 509                
 510                self.psi_ops.append((coeff, psi))
 511        else:
 512            dispersion = solve(Eq(equation, 0), omega)
 513            if not dispersion:
 514                raise ValueError("No solution found for omega")
 515            print("\n--- Solutions found ---")
 516            pprint(dispersion, num_columns=NUM_COLS)
 517        
 518            if self.temporal_order == 2:
 519                omega_expr = simplify(sqrt(dispersion[0]**2))
 520                self.omega_symbolic = omega_expr
 521                self.omega = lambdify(self.k_symbols, omega_expr, "numpy")
 522                self.L_symbolic = -omega_expr**2
 523            else:
 524                self.L_symbolic = -I * dispersion[0]
 525        
 526        
 527            self.L = lambdify(self.k_symbols, self.L_symbolic, "numpy")
 528  
 529            print("\n--- Final linear operator ---")
 530            pprint(self.L_symbolic, num_columns=NUM_COLS)   
 531
 532    def _linear_rhs(self, u, is_v=False):
 533        """
 534        Apply the linear operator (in Fourier space) to the field u or v.
 535
 536        Parameters
 537        ----------
 538        u : np.ndarray
 539            Input solution array.
 540        is_v : bool
 541            Whether to apply the operator to v instead of u.
 542
 543        Returns
 544        -------
 545        np.ndarray
 546            Result of applying the linear operator.
 547        """
 548        if self.dim == 1:
 549            self.symbol_u = np.array(self.L(self.KX), dtype=np.complex128)
 550            self.symbol_v = self.symbol_u  # même opérateur pour u et v
 551        elif self.dim == 2:
 552            self.symbol_u = np.array(self.L(self.KX, self.KY), dtype=np.complex128)
 553            self.symbol_v = self.symbol_u
 554        u_hat = self.fft(u)
 555        u_hat *= self.symbol_v if is_v else self.symbol_u
 556        u_hat *= self.dealiasing_mask
 557        return self.ifft(u_hat)
 558
 559    def setup(self, Lx, Ly=None, Nx=None, Ny=None, Lt=1.0, Nt=100, boundary_condition='periodic',
 560              initial_condition=None, initial_velocity=None, n_frames=100, plot=True):
 561        """
 562        Configure the spatial/temporal grid and initialize the solution field.
 563    
 564        This method sets up the computational domain, initializes spatial and temporal grids,
 565        applies boundary conditions, and prepares symbolic and numerical operators.
 566        It also performs essential analyses such as:
 567        
 568            - CFL condition verification (for stability)
 569            - Symbol analysis (e.g., dispersion relation, regularity)
 570            - Wave propagation analysis for second-order equations
 571    
 572        If pseudo-differential operators (ψOp) are present, symbolic analysis is skipped
 573        in favor of interactive exploration via `interactive_symbol_analysis`.
 574    
 575        Parameters
 576        ----------
 577        Lx : float
 578            Size of the spatial domain along x-axis.
 579        Ly : float, optional
 580            Size of the spatial domain along y-axis (for 2D problems).
 581        Nx : int
 582            Number of spatial points along x-axis.
 583        Ny : int, optional
 584            Number of spatial points along y-axis (for 2D problems).
 585        Lt : float, default=1.0
 586            Total simulation time.
 587        Nt : int, default=100
 588            Number of time steps.
 589        initial_condition : callable
 590            Function returning the initial state u(x, 0) or u(x, y, 0).
 591        initial_velocity : callable, optional
 592            Function returning the initial time derivative ∂ₜu(x, 0) or ∂ₜu(x, y, 0),
 593            required for second-order equations.
 594        n_frames : int, default=100
 595            Number of time frames to store during simulation for visualization or output.
 596    
 597        Raises
 598        ------
 599        ValueError
 600            If mandatory parameters are missing (e.g., Nx not given in 1D, Ly/Ny not given in 2D).
 601    
 602        Notes
 603        -----
 604        - The spatial discretization assumes periodic boundary conditions by default.
 605        - Fourier transforms are computed using real-to-complex FFTs (`scipy.fft.fft`, `fft2`).
 606        - Frequency arrays (`KX`, `KY`) are defined following standard spectral conventions.
 607        - Dealiasing is applied using a sharp cutoff filter at a fraction of the maximum frequency.
 608        - For second-order equations, initial acceleration is derived from the governing operator.
 609        - Symbolic analysis includes plotting of the symbol's real/imaginary/absolute values
 610          and dispersion relation.
 611    
 612        See Also
 613        --------
 614        setup_1D : Sets up internal variables for one-dimensional problems.
 615        setup_2D : Sets up internal variables for two-dimensional problems.
 616        initialize_conditions : Applies initial data and enforces compatibility.
 617        check_cfl_condition : Verifies time step against stability constraints.
 618        plot_symbol : Visualizes the linear operator’s symbol in frequency space.
 619        analyze_wave_propagation : Analyzes group velocity.
 620        interactive_symbol_analysis : Interactive tools for ψOp-based equations.
 621        """
 622        
 623        # Temporal parameters
 624        self.Lt, self.Nt = Lt, Nt
 625        self.dt = Lt / Nt
 626        self.n_frames = n_frames
 627        self.frames = []
 628        self.initial_condition = initial_condition
 629        self.boundary_condition = boundary_condition
 630        self.plot = plot
 631
 632        if self.boundary_condition == 'dirichlet' and not self.has_psi:
 633            raise ValueError(
 634                "Dirichlet boundary conditions require the equation to be defined via a pseudo-differential operator (psiOp). "
 635                "Please provide an equation involving psiOp for non-periodic boundary treatment."
 636            )
 637    
 638        # Dimension checks
 639        if self.dim == 1:
 640            if Nx is None:
 641                raise ValueError("Nx must be specified in 1D.")
 642            self._setup_1D(Lx, Nx)
 643        else:
 644            if None in (Ly, Ny):
 645                raise ValueError("In 2D, Ly and Ny must be provided.")
 646            self._setup_2D(Lx, Ly, Nx, Ny)
 647    
 648        # Initialization of solution and velocities
 649        if not self.is_stationary:
 650            self._initialize_conditions(initial_condition, initial_velocity)
 651            
 652        # Symbol analysis if present
 653        if self.has_psi:
 654            print("⚠️ For psiOp, use interactive_symbol_analysis.")
 655        else:
 656            if self.L_symbolic == 0:
 657                print("⚠️ Linear operator is null.")
 658            else:
 659                self._check_cfl_condition()
 660                self._check_symbol_conditions()
 661                if plot:
 662                	self._plot_symbol()
 663                	if self.temporal_order == 2:
 664                		self._analyze_wave_propagation()
 665
 666    def _setup_1D(self, Lx, Nx):
 667        """
 668        Configure internal variables for one-dimensional (1D) problems.
 669    
 670        This private method initializes spatial and frequency grids, applies dealiasing,
 671        and prepares either pseudo-differential symbols or linear operators for use in time evolution.
 672        
 673        It assumes periodic boundary conditions and uses real-to-complex FFT conventions.
 674        The spatial domain is centered at zero: [-Lx/2, Lx/2].
 675    
 676        Parameters
 677        ----------
 678        Lx : float
 679            Physical size of the spatial domain along the x-axis.
 680        Nx : int
 681            Number of grid points in the x-direction.
 682    
 683        Attributes Set
 684        --------------
 685        - self.Lx : float
 686            Size of the spatial domain.
 687        - self.Nx : int
 688            Number of spatial points.
 689        - self.x_grid : np.ndarray
 690            1D array of spatial coordinates.
 691        - self.X : np.ndarray
 692            Alias to `self.x_grid`, used in physical space computations.
 693        - self.kx : np.ndarray
 694            Array of wavenumbers corresponding to the Fourier transform.
 695        - self.KX : np.ndarray
 696            Alias to `self.kx`, used in frequency space computations.
 697        - self.dealiasing_mask : np.ndarray
 698            Boolean mask used to suppress aliased frequencies during nonlinear calculations.
 699        - self.exp_L : np.ndarray
 700            Exponential of the linear operator scaled by time step: exp(L(k) · dt).
 701        - self.omega_val : np.ndarray
 702            Frequency values ω(k) = Re[√(L(k))] used in second-order time stepping.
 703        - self.cos_omega_dt, self.sin_omega_dt : np.ndarray
 704            Cosine and sine of ω(k)·dt for dispersive propagation.
 705        - self.inv_omega : np.ndarray
 706            Inverse of ω(k), used to avoid division-by-zero in time stepping.
 707    
 708        Notes
 709        -----
 710        - Frequencies are computed using `scipy.fft.fftfreq` and then shifted to center zero frequency.
 711        - Dealiasing is applied using a sharp cutoff filter based on `self.dealiasing_ratio`.
 712        - If pseudo-differential operators (ψOp) are present, symbolic tables are precomputed via `prepare_symbol_tables`.
 713        - For second-order equations, the dispersion relation ω(k) is extracted from the linear operator L(k).
 714    
 715        See Also
 716        --------
 717        setup_2D : Equivalent setup for two-dimensional problems.
 718        prepare_symbol_tables : Precomputes symbolic arrays for ψOp evaluation.
 719        setup_omega_terms : Sets up terms involving ω(k) for second-order evolution.
 720        """
 721        self.Lx, self.Nx = Lx, Nx
 722        self.x_grid = np.linspace(-Lx/2, Lx/2, Nx, endpoint=False)
 723        self.X = self.x_grid
 724        self.kx = 2 * np.pi * fftfreq(Nx, d=Lx / Nx)
 725        self.KX = self.kx
 726    
 727        # Dealiasing mask
 728        k_max = self.dealiasing_ratio * np.max(np.abs(self.kx))
 729        self.dealiasing_mask = (np.abs(self.KX) <= k_max)
 730    
 731        # Preparation of symbol or linear operator
 732        if self.has_psi:
 733            self._prepare_symbol_tables()
 734        else:
 735            L_vals = np.array(self.L(self.KX), dtype=np.complex128)
 736            self.exp_L = np.exp(L_vals * self.dt)
 737            if self.temporal_order == 2:
 738                omega_val = self.omega(self.KX)
 739                self._setup_omega_terms(omega_val)
 740    
 741    def _setup_2D(self, Lx, Ly, Nx, Ny):
 742        """
 743        Configure internal variables for two-dimensional (2D) problems.
 744    
 745        This private method initializes spatial and frequency grids, applies dealiasing,
 746        and prepares either pseudo-differential symbols or linear operators for use in time evolution.
 747        
 748        It assumes periodic boundary conditions and uses real-to-complex FFT conventions.
 749        The spatial domain is centered at zero: [-Lx/2, Lx/2] × [-Ly/2, Ly/2].
 750    
 751        Parameters
 752        ----------
 753        Lx : float
 754            Physical size of the spatial domain along the x-axis.
 755        Ly : float
 756            Physical size of the spatial domain along the y-axis.
 757        Nx : int
 758            Number of grid points along the x-direction.
 759        Ny : int
 760            Number of grid points along the y-direction.
 761    
 762        Attributes Set
 763        --------------
 764        - self.Lx, self.Ly : float
 765            Size of the spatial domain in each direction.
 766        - self.Nx, self.Ny : int
 767            Number of spatial points in each direction.
 768        - self.x_grid, self.y_grid : np.ndarray
 769            1D arrays of spatial coordinates in x and y directions.
 770        - self.X, self.Y : np.ndarray
 771            2D meshgrids of spatial coordinates for physical space computations.
 772        - self.kx, self.ky : np.ndarray
 773            Arrays of wavenumbers corresponding to Fourier transforms in x and y directions.
 774        - self.KX, self.KY : np.ndarray
 775            Meshgrids of wavenumbers used in frequency space computations.
 776        - self.dealiasing_mask : np.ndarray
 777            Boolean mask used to suppress aliased frequencies during nonlinear calculations.
 778        - self.exp_L : np.ndarray
 779            Exponential of the linear operator scaled by time step: exp(L(kx, ky) · dt).
 780        - self.omega_val : np.ndarray
 781            Frequency values ω(kx, ky) = Re[√(L(kx, ky))] used in second-order time stepping.
 782        - self.cos_omega_dt, self.sin_omega_dt : np.ndarray
 783            Cosine and sine of ω(kx, ky)·dt for dispersive propagation.
 784        - self.inv_omega : np.ndarray
 785            Inverse of ω(kx, ky), used to avoid division-by-zero in time stepping.
 786    
 787        Notes
 788        -----
 789        - Frequencies are computed using `scipy.fft.fftfreq` and then shifted to center zero frequency.
 790        - Dealiasing is applied using a sharp cutoff filter based on `self.dealiasing_ratio`.
 791        - If pseudo-differential operators (ψOp) are present, symbolic tables are precomputed via `prepare_symbol_tables`.
 792        - For second-order equations, the dispersion relation ω(kx, ky) is extracted from the linear operator L(kx, ky).
 793    
 794        See Also
 795        --------
 796        setup_1D : Equivalent setup for one-dimensional problems.
 797        prepare_symbol_tables : Precomputes symbolic arrays for ψOp evaluation.
 798        setup_omega_terms : Sets up terms involving ω(kx, ky) for second-order evolution.
 799        """
 800        self.Lx, self.Ly = Lx, Ly
 801        self.Nx, self.Ny = Nx, Ny
 802        self.x_grid = np.linspace(-Lx/2, Lx/2, Nx, endpoint=False)
 803        self.y_grid = np.linspace(-Ly/2, Ly/2, Ny, endpoint=False)
 804        self.X, self.Y = np.meshgrid(self.x_grid, self.y_grid, indexing='ij')
 805        self.kx = 2 * np.pi * fftfreq(Nx, d=Lx / Nx)
 806        self.ky = 2 * np.pi * fftfreq(Ny, d=Ly / Ny)
 807        self.KX, self.KY = np.meshgrid(self.kx, self.ky, indexing='ij')
 808    
 809        # Dealiasing mask
 810        kx_max = self.dealiasing_ratio * np.max(np.abs(self.kx))
 811        ky_max = self.dealiasing_ratio * np.max(np.abs(self.ky))
 812        self.dealiasing_mask = (np.abs(self.KX) <= kx_max) & (np.abs(self.KY) <= ky_max)
 813    
 814        # Preparation of symbol or linear operator
 815        if self.has_psi:
 816            self._prepare_symbol_tables()
 817        else:
 818            L_vals = self.L(self.KX, self.KY)
 819            self.exp_L = np.exp(L_vals * self.dt)
 820            if self.temporal_order == 2:
 821                omega_val = self.omega(self.KX, self.KY)
 822                self._setup_omega_terms(omega_val)
 823    
 824    def _setup_omega_terms(self, omega_val):
 825        """
 826        Initialize terms derived from the angular frequency ω for time evolution.
 827    
 828        This private method precomputes and stores key trigonometric and inverse quantities
 829        based on the dispersion relation ω(k), used in second-order time integration schemes.
 830        
 831        These values are essential for solving wave-like equations with dispersive behavior:
 832            cos(ω·dt), sin(ω·dt), 1/ω
 833        
 834        The inverse frequency is computed safely to avoid division by zero.
 835    
 836        Parameters
 837        ----------
 838        omega_val : np.ndarray
 839            Array of angular frequency values ω(k) evaluated at discrete wavenumbers.
 840            Can be one-dimensional (1D) or two-dimensional (2D) depending on spatial dimension.
 841    
 842        Attributes Set
 843        --------------
 844        - self.omega_val : np.ndarray
 845            Copy of the input angular frequency array.
 846        - self.cos_omega_dt : np.ndarray
 847            Cosine of ω(k) multiplied by time step: cos(ω(k) · dt).
 848        - self.sin_omega_dt : np.ndarray
 849            Sine of ω(k) multiplied by time step: sin(ω(k) · dt).
 850        - self.inv_omega : np.ndarray
 851            Inverse of ω(k), with zeros where ω(k) == 0 to avoid division by zero.
 852    
 853        Notes
 854        -----
 855        - This method is typically called during setup when solving second-order PDEs
 856          involving dispersive waves (e.g., Klein-Gordon, Schrödinger, or water wave equations).
 857        - The safe computation of 1/ω ensures numerical stability even when low frequencies are present.
 858        - These precomputed arrays are used in spectral propagators for accurate time stepping.
 859    
 860        See Also
 861        --------
 862        setup_1D : Sets up internal variables for one-dimensional problems.
 863        setup_2D : Sets up internal variables for two-dimensional problems.
 864        solve : Time integration using the computed frequency terms.
 865        """
 866        self.omega_val = omega_val
 867        self.cos_omega_dt = np.cos(omega_val * self.dt)
 868        self.sin_omega_dt = np.sin(omega_val * self.dt)
 869        self.inv_omega = np.zeros_like(omega_val)
 870        nonzero = omega_val != 0
 871        self.inv_omega[nonzero] = 1.0 / omega_val[nonzero]
 872
 873    def _evaluate_source_at_t0(self):
 874        """
 875        Evaluate source terms at initial time t = 0 over the spatial grid.
 876    
 877        This private method computes the total contribution of all source terms at the initial time,
 878        evaluated across the entire spatial domain. It supports both one-dimensional (1D) and
 879        two-dimensional (2D) configurations.
 880    
 881        Returns
 882        -------
 883        np.ndarray
 884            A numpy array representing the evaluated source term at t=0:
 885            - In 1D: Shape (Nx,), evaluated at each x in `self.x_grid`.
 886            - In 2D: Shape (Nx, Ny), evaluated at each (x, y) pair in the grid.
 887    
 888        Notes
 889        -----
 890        - The symbolic expressions in `self.source_terms` are substituted with numerical values at t=0.
 891        - In 1D, each term is evaluated at (t=0, x=x_val).
 892        - In 2D, each term is evaluated at (t=0, x=x_val, y=y_val).
 893        - Evaluated using SymPy's `evalf()` to ensure numeric conversion.
 894        - This method assumes that the source terms have already been lambdified or are compatible with symbolic substitution.
 895    
 896        See Also
 897        --------
 898        setup : Initializes the spatial grid and source terms.
 899        solve : Uses this evaluation during the first time step.
 900        """
 901        if self.dim == 1:
 902            # Evaluation on the 1D spatial grid
 903            return np.array([
 904                sum(term.subs(self.t, 0).subs(self.x, x_val).evalf()
 905                    for term in self.source_terms)
 906                for x_val in self.x_grid
 907            ], dtype=np.float64)
 908        else:
 909            # Evaluation on the 2D spatial grid
 910            return np.array([
 911                [sum(term.subs({self.t: 0, self.x: x_val, self.y: y_val}).evalf()
 912                      for term in self.source_terms)
 913                 for y_val in self.y_grid]
 914                for x_val in self.x_grid
 915            ], dtype=np.float64)
 916    
 917    def _initialize_conditions(self, initial_condition, initial_velocity):
 918        """
 919        Initialize the solution and velocity fields at t = 0.
 920    
 921        This private method sets up the initial state of the solution `u_prev` and, if applicable,
 922        the time derivative (velocity) `v_prev` for second-order evolution equations.
 923        
 924        For second-order equations, it also computes the backward-in-time value `u_prev2`
 925        needed by the Leap-Frog method. The acceleration at t = 0 is computed from:
 926            ∂ₜ²u = L(u) + N(u) + f(x, t=0)
 927        where L is the linear operator, N is the nonlinear term, and f is the source term.
 928    
 929        Parameters
 930        ----------
 931        initial_condition : callable
 932            Function returning the initial condition u(x, 0) or u(x, y, 0).
 933        initial_velocity : callable or None
 934            Function returning the initial velocity ∂ₜu(x, 0) or ∂ₜu(x, y, 0). Required for
 935            second-order equations; ignored otherwise.
 936    
 937        Raises
 938        ------
 939        ValueError
 940            If `initial_velocity` is not provided for second-order equations.
 941    
 942        Notes
 943        -----
 944        - Applies periodic boundary conditions after setting initial data.
 945        - Stores a copy of the initial state in `self.frames` for visualization/output.
 946        - In second-order systems, initializes `self.u_prev2` using a Taylor expansion:
 947          u_prev2 = u_prev - dt * v_prev + 0.5 * dt² * (∂ₜ²u)
 948    
 949        See Also
 950        --------
 951        apply_boundary : Enforces periodic boundary conditions on the solution field.
 952        psiOp_apply : Computes pseudo-differential operator action for acceleration.
 953        linear_rhs : Evaluates linear part of the equation in Fourier space.
 954        apply_nonlinear : Handles nonlinear terms with spectral differentiation.
 955        evaluate_source_at_t0 : Evaluates source terms at the initial time.
 956        """
 957        # Initial condition
 958        if self.dim == 1:
 959            self.u_prev = initial_condition(self.X)
 960        else:
 961            self.u_prev = initial_condition(self.X, self.Y)
 962        self._apply_boundary(self.u_prev)
 963    
 964        # Initial velocity (second order)
 965        if self.temporal_order == 2:
 966            if initial_velocity is None:
 967                raise ValueError("Initial velocity is required for second-order equations.")
 968            if self.dim == 1:
 969                self.v_prev = initial_velocity(self.X)
 970            else:
 971                self.v_prev = initial_velocity(self.X, self.Y)
 972            self.u0 = np.copy(self.u_prev)
 973            self.v0 = np.copy(self.v_prev)
 974    
 975            # Calculation of u_prev2 (initial acceleration)
 976            if not hasattr(self, 'u_prev2'):
 977                if self.has_psi:
 978                    acc0 = -self._apply_psiOp(self.u_prev)
 979                else:
 980                    acc0 = self._linear_rhs(self.u_prev, is_v=False)
 981                rhs_nl = self._apply_nonlinear(self.u_prev, is_v=False)
 982                acc0 += rhs_nl
 983                if hasattr(self, 'source_terms') and self.source_terms:
 984                    acc0 += self._evaluate_source_at_t0()
 985                self.u_prev2 = self.u_prev - self.dt * self.v_prev + 0.5 * self.dt**2 * acc0
 986    
 987        self.frames = [self.u_prev.copy()]
 988           
 989    def _apply_boundary(self, u):
 990        """
 991        Apply boundary conditions to the solution array based on the specified type.
 992    
 993        This method supports two types of boundary conditions:
 994        
 995        - 'periodic': Enforces periodicity by copying opposite boundary values.
 996        - 'dirichlet': Sets all boundary values to zero (homogeneous Dirichlet condition).
 997    
 998        Parameters
 999        ----------
1000        u : np.ndarray
1001            The solution array representing the field values on a spatial grid.
1002            In 1D, shape must be (Nx,). In 2D, shape must be (Nx, Ny).
1003    
1004        Raises
1005        ------
1006        ValueError
1007            If `self.boundary_condition` is not one of {'periodic', 'dirichlet'}.
1008    
1009        Notes
1010        -----
1011        - For 'periodic':
1012            * In 1D: u[0] = u[-2], u[-1] = u[1]
1013            * In 2D: First and last rows/columns are set equal to their neighbors.
1014        - For 'dirichlet':
1015            * All boundary points are explicitly set to zero.
1016        """
1017    
1018        if self.boundary_condition == 'periodic':
1019            if self.dim == 1:
1020                u[0] = u[-2]
1021                u[-1] = u[1]
1022            elif self.dim == 2:
1023                u[0, :] = u[-2, :]
1024                u[-1, :] = u[1, :]
1025                u[:, 0] = u[:, -2]
1026                u[:, -1] = u[:, 1]
1027    
1028        elif self.boundary_condition == 'dirichlet':
1029            if self.dim == 1:
1030                u[0] = 0
1031                u[-1] = 0
1032            elif self.dim == 2:
1033                u[0, :] = 0
1034                u[-1, :] = 0
1035                u[:, 0] = 0
1036                u[:, -1] = 0
1037    
1038        else:
1039            raise ValueError(
1040                f"Invalid boundary condition '{self.boundary_condition}'. "
1041                "Supported types are 'periodic' and 'dirichlet'."
1042            )
1043
1044    def _apply_nonlinear(self, u, is_v=False):
1045        """
1046        Apply nonlinear terms to the solution using spectral differentiation with dealiasing.
1047
1048        This method evaluates all nonlinear terms present in the PDE by substituting spatial 
1049        derivatives with their spectral approximations computed via FFT. The dealiasing mask 
1050        ensures numerical stability by removing high-frequency components that could lead 
1051        to aliasing errors.
1052
1053        Parameters
1054        ----------
1055        u : numpy.ndarray
1056            Current solution array on the spatial grid.
1057        is_v : bool
1058            If True, evaluates nonlinear terms for the velocity field v instead of u.
1059
1060        Returns:
1061            numpy.ndarray: Array representing the contribution of nonlinear terms multiplied by dt.
1062
1063        Notes:
1064        
1065        - In 1D, computes ∂ₓu via FFT and substitutes any derivative term in the nonlinear expressions.
1066        - In 2D, computes ∂ₓu and ∂ᵧu via FFT and performs similar substitutions.
1067        - Uses lambdify to evaluate symbolic nonlinear expressions numerically.
1068        - Derivatives are replaced symbolically with 'u_x' and 'u_y' before evaluation.
1069        """
1070        if not self.nonlinear_terms:
1071            return np.zeros_like(u, dtype=np.complex128)
1072        
1073        nonlinear_term = np.zeros_like(u, dtype=np.complex128)
1074    
1075        if self.dim == 1:
1076            u_hat = self.fft(u)
1077            u_hat *= self.dealiasing_mask
1078            u = self.ifft(u_hat)
1079    
1080            u_x_hat = (1j * self.KX) * u_hat
1081            u_x = self.ifft(u_x_hat)
1082    
1083            for term in self.nonlinear_terms:
1084                term_replaced = term
1085                if term.has(Derivative):
1086                    for deriv in term.atoms(Derivative):
1087                        if deriv.args[1][0] == self.x:
1088                            term_replaced = term_replaced.subs(deriv, symbols('u_x'))            
1089                term_func = lambdify((self.t, self.x, self.u_eq, 'u_x'), term_replaced, 'numpy')
1090                if is_v:
1091                    nonlinear_term += term_func(0, self.X, self.v_prev, u_x)
1092                else:
1093                    nonlinear_term += term_func(0, self.X, u, u_x)
1094    
1095        elif self.dim == 2:
1096            u_hat = self.fft(u)
1097            u_hat *= self.dealiasing_mask
1098            u = self.ifft(u_hat)
1099    
1100            u_x_hat = (1j * self.KX) * u_hat
1101            u_y_hat = (1j * self.KY) * u_hat
1102            u_x = self.ifft(u_x_hat)
1103            u_y = self.ifft(u_y_hat)
1104    
1105            for term in self.nonlinear_terms:
1106                term_replaced = term
1107                if term.has(Derivative):
1108                    for deriv in term.atoms(Derivative):
1109                        if deriv.args[1][0] == self.x:
1110                            term_replaced = term_replaced.subs(deriv, symbols('u_x'))
1111                        elif deriv.args[1][0] == self.y:
1112                            term_replaced = term_replaced.subs(deriv, symbols('u_y'))
1113                term_func = lambdify((self.t, self.x, self.y, self.u_eq, 'u_x', 'u_y'), term_replaced, 'numpy')
1114                if is_v:
1115                    nonlinear_term += term_func(0, self.X, self.Y, self.v_prev, u_x, u_y)
1116                else:
1117                    nonlinear_term += term_func(0, self.X, self.Y, u, u_x, u_y)
1118        else:
1119            raise ValueError("Unsupported spatial dimension.")
1120        
1121        return nonlinear_term * self.dt
1122
1123    def _prepare_symbol_tables(self):
1124        """
1125        Precompute and store evaluated pseudo-differential operator symbols for spectral methods.
1126
1127        This method evaluates all pseudo-differential operators (ψOp) present in the PDE
1128        over the spatial and frequency grids, scales them by their respective coefficients,
1129        and combines them into a single composite symbol used in time-stepping and inversion.
1130
1131        The evaluation is performed via the `evaluate` method of each PseudoDifferentialOperator,
1132        which computes p(x, ξ) or p(x, y, ξ, η) numerically over the current grid configuration.
1133
1134        Side Effects:
1135            self.precomputed_symbols : list of (coeff, symbol_array)
1136                Each tuple contains a coefficient and its evaluated symbol on the grid.
1137            self.combined_symbol : np.ndarray
1138                Sum of all scaled symbol arrays: ∑(coeffₖ * ψₖ(x, ξ))
1139
1140        Raises:
1141            ValueError: If the spatial dimension is not 1D or 2D.
1142        """
1143        self.precomputed_symbols = []
1144        self.combined_symbol = 0
1145        for coeff, psi in self.psi_ops:
1146            if self.dim == 1:
1147                raw = psi.evaluate(self.X, None, self.KX, None)
1148            elif self.dim == 2:
1149                raw = psi.evaluate(self.X, self.Y, self.KX, self.KY)
1150            else:
1151                raise ValueError('Unsupported spatial dimension.')
1152            raw_flat = raw.flatten()
1153            converted = np.array([complex(N(val)) for val in raw_flat], dtype=np.complex128)
1154            raw_eval = converted.reshape(raw.shape)
1155            self.precomputed_symbols.append((coeff, raw_eval))
1156        self.combined_symbol = sum((coeff * sym for coeff, sym in self.precomputed_symbols))
1157        self.combined_symbol = np.array(self.combined_symbol, dtype=np.complex128)
1158
1159    def _total_symbol_expr(self):
1160        """
1161        Compute the total pseudo-differential symbol expression from all pseudo_terms.
1162
1163        This method constructs the full symbol of the pseudo-differential operator
1164        by summing up all coefficient-weighted symbolic expressions.
1165
1166        The result is cached in self.symbol_expr to avoid recomputation.
1167
1168        Returns:
1169            sympy.Expr: The combined symbol expression, representing the full
1170                        pseudo-differential operator in symbolic form.
1171
1172        Example:
1173            Given pseudo_terms = [(2, ξ²), (1, x·ξ)], this returns 2·ξ² + x·ξ.
1174        """
1175        if not hasattr(self, '_symbol_expr'):
1176            self.symbol_expr = sum(coeff * expr for coeff, expr in self.pseudo_terms)
1177        return self.symbol_expr
1178
1179    def _build_symbol_func(self, expr):
1180        """
1181        Build a numerical evaluation function from a symbolic pseudo-differential operator expression.
1182    
1183        This method converts a symbolic expression representing a pseudo-differential operator into
1184        a callable NumPy-compatible function. The function accepts spatial and frequency variables
1185        depending on the dimensionality of the problem.
1186    
1187        Parameters
1188        ----------
1189        expr : sympy expression
1190            A SymPy expression representing the symbol of the pseudo-differential operator. It may depend on spatial variables (x, y) and frequency variables (xi, eta).
1191    
1192        Returns:
1193            function : A lambdified function that takes:
1194            
1195                - In 1D: `(x, xi)` — spatial coordinate and frequency.
1196                - In 2D: `(x, y, xi, eta)` — spatial coordinates and frequencies.
1197                
1198              Returns a NumPy array of evaluated symbol values over input grids.
1199    
1200        Notes:
1201            - Uses `lambdify` from SymPy with the `'numpy'` backend for efficient vectorized evaluation.
1202            - Real variable assumptions are enforced to ensure proper behavior in numerical contexts.
1203            - Used internally by methods like `apply_psiOp`, `evaluate`, and visualization tools.
1204        """
1205        if self.dim == 1:
1206            x, xi = symbols('x xi', real=True)
1207            return lambdify((x, xi), expr, 'numpy')
1208        else:
1209            x, y, xi, eta = symbols('x y xi eta', real=True)
1210            return lambdify((x, y, xi, eta), expr, 'numpy')
1211
1212    def _apply_psiOp(self, u):
1213        """
1214        Apply the pseudo-differential operator to the input field u.
1215    
1216        This method dispatches the application of the pseudo-differential operator based on:
1217        
1218        - Whether the symbol is spatially dependent (x/y)
1219        - The boundary condition in use (periodic or dirichlet)
1220    
1221        Supported operations:
1222        
1223        - Constant-coefficient symbols: applied via Fourier multiplication.
1224        - Spatially varying symbols: applied via Kohn–Nirenberg quantization.
1225        - Dirichlet boundary conditions: handled with non-periodic convolution-like quantization.
1226    
1227        Dispatch Logic:\n
1228        if not self.is_spatial: u ↦ Op(p)(D) ⋅ u = 𝓕⁻¹[ p(ξ) ⋅ 𝓕(u) ]\n
1229        elif periodic: u ↦ Op(p)(x,D) ⋅ u ≈ ∫ eᶦˣᶿ p(x, ξ) 𝓕(u)(ξ) dξ based of FFT (quicker)\n
1230        elif dirichlet: u ↦ Op(p)(x,D) ⋅ u ≈ u ≈ ∫ eᶦˣᶿ p(x, ξ) 𝓕(u)(ξ) dξ (slower)\n
1231        
1232        This method delegates to the apply() method of each 
1233        PseudoDifferentialOperator instance.
1234        
1235        Parameters
1236        ----------
1237        u : ndarray
1238            Function to which operators are applied
1239            
1240        Returns
1241        -------
1242        ndarray
1243            Result of applying all operators with their coefficients
1244        """
1245        if not hasattr(self, 'psi_ops') or not self.psi_ops:
1246            raise ValueError("No pseudo-differential operators defined")
1247        
1248        result = np.zeros_like(u, dtype=np.complex128)
1249        
1250        for coeff, psi_op in self.psi_ops:
1251            coeff = np.complex128(coeff)
1252            if self.dim == 1:
1253                contribution = psi_op.apply(
1254                    u=u,
1255                    x_grid=self.x_grid,
1256                    kx=self.kx,
1257                    boundary_condition=self.boundary_condition,
1258                    dealiasing_mask=self.dealiasing_mask
1259                )
1260            elif self.dim == 2:
1261                contribution = psi_op.apply(
1262                    u=u,
1263                    x_grid=self.x_grid,
1264                    kx=self.kx,
1265                    y_grid=self.y_grid,
1266                    ky=self.ky,
1267                    boundary_condition=self.boundary_condition,
1268                    dealiasing_mask=self.dealiasing_mask
1269                )
1270            else:
1271                raise ValueError("Only 1D and 2D supported")
1272            
1273            result += coeff * contribution
1274        
1275        return result
1276
1277    def _step_order1_with_psi(self, source_contribution):
1278        """
1279        Perform one time step of a first-order evolution using a pseudo-differential operator.
1280    
1281        This method updates the solution field using an exponential integrator or explicit Euler scheme,
1282        depending on boundary conditions and the structure of the pseudo-differential symbol.
1283        It supports:
1284        - Linear dynamics via pseudo-differential operator L (possibly nonlocal)
1285        - Nonlinear terms computed via spectral differentiation
1286        - External source contributions
1287    
1288        The update follows **three distinct computational paths**:
1289    
1290        1. **Periodic boundaries + diagonalizable symbol**  
1291           Symbol is constant in space → use direct Fourier-based exponential integrator:  
1292               uₙ₊₁ = e⁻ᴸΔᵗ ⋅ uₙ + Δt ⋅ φ₁(−LΔt) ⋅ (N(uₙ) + F)
1293    
1294        2. **Non-diagonalizable but spatially uniform symbol**  
1295           General exponential time differencing of order 1:  
1296               uₙ₊₁ = eᴸΔᵗ ⋅ uₙ + Δt ⋅ φ₁(LΔt) ⋅ (N(uₙ) + F)
1297    
1298        3. **Spatially varying symbol**  
1299           No frequency diagonalization available → use explicit Euler:  
1300               uₙ₊₁ = uₙ + Δt ⋅ (L(uₙ) + N(uₙ) + F)
1301    
1302        where:
1303            L(uₙ) = linear part via pseudo-differential operator
1304            N(uₙ) = nonlinear contribution at current time step
1305            F     = external source term
1306            Δt    = time step size
1307            φ₁(z) = (eᶻ − 1)/z (with safe handling near z=0)
1308    
1309        Boundary conditions are applied after each update to ensure consistency.
1310    
1311        Parameters
1312            source_contribution (np.ndarray): Array representing the external source term at current time step.
1313                                              Must match the spatial dimensions of self.u_prev.
1314    
1315        Returns:
1316            np.ndarray: Updated solution array after one time step.
1317        """
1318        # Handling null source
1319        if np.isscalar(source_contribution):
1320            source = np.zeros_like(self.u_prev)
1321        else:
1322            source = source_contribution
1323
1324        def _spectral_filter(u, cutoff=0.8):
1325            if u.ndim == 1:
1326                u_hat = self.fft(u)
1327                N = len(u)
1328                k = fftfreq(N)
1329                mask = np.exp(-(k / cutoff)**8)
1330                return self.ifft(u_hat * mask).real
1331            elif u.ndim == 2:
1332                u_hat = self.fft(u)
1333                Ny, Nx = u.shape
1334                ky = fftfreq(Ny)[:, None]
1335                kx = fftfreq(Nx)[None, :]
1336                k_squared = kx**2 + ky**2
1337                mask = np.exp(-(np.sqrt(k_squared) / cutoff)**8)
1338                return self.ifft(u_hat * mask).real
1339            else:
1340                raise ValueError("Only 1D and 2D arrays are supported.")
1341
1342        # Recalculate symbol if necessary
1343        if self.is_spatial:
1344            self._prepare_symbol_tables()  # Recalculates self.combined_symbol
1345    
1346        # Case with FFT (symbol diagonalizable in Fourier space)
1347        if self.boundary_condition == 'periodic' and not self.is_spatial:
1348            u_hat = self.fft(self.u_prev)
1349            u_hat *= np.exp(-self.dt * self.combined_symbol)
1350            u_hat *= self.dealiasing_mask
1351            u_symb = self.ifft(u_hat)
1352            u_nl = self._apply_nonlinear(self.u_prev)
1353            u_new = u_symb + u_nl + source
1354        else:
1355            if not self.is_spatial:
1356                # General case with ETD1
1357                u_nl = self._apply_nonlinear(self.u_prev)
1358    
1359                # Calculation of exp(dt * L) and phi1(dt * L)
1360                L_vals = self.combined_symbol  # Uses the updated symbol
1361                exp_L = np.exp(-self.dt * L_vals)
1362                phi1_L = (exp_L - 1.0) / (self.dt * L_vals)
1363                phi1_L[np.isnan(phi1_L)] = 1.0  # Handling division by zero
1364    
1365                # Fourier transform
1366                u_hat = self.fft(self.u_prev)
1367                u_nl_hat = self.fft(u_nl)
1368                source_hat = self.fft(source)
1369    
1370                # Assembling the solution in Fourier space
1371                u_hat_new = exp_L * u_hat + self.dt * phi1_L * (u_nl_hat + source_hat)
1372                u_new = self.ifft(u_hat_new)
1373            else:
1374                # if the symbol depends on spatial variables : Euler method
1375                Lu_prev = -self._apply_psiOp(self.u_prev)
1376                u_nl = self._apply_nonlinear(self.u_prev)
1377                u_new = self.u_prev + self.dt * (Lu_prev + u_nl + source)
1378                u_new = _spectral_filter(u_new, cutoff=self.dealiasing_ratio)
1379        # Applying boundary conditions
1380        self._apply_boundary(u_new)
1381        return u_new
1382
1383    def _step_order2_with_psi(self, source_contribution):
1384        """
1385        Perform one time step of a second-order time evolution using a pseudo-differential operator.
1386    
1387        This method updates the solution field using a second-order accurate scheme suitable for wave-like equations.
1388        The update includes contributions from:
1389        - Linear dynamics via a pseudo-differential operator (e.g., dispersion or stiffness)
1390        - Nonlinear terms computed via spectral differentiation
1391        - External source contributions
1392    
1393        Discretization follows a leapfrog-style finite difference in time:
1394        
1395            uₙ₊₁ = 2uₙ − uₙ₋₁ + Δt² ⋅ (L(uₙ) + N(uₙ) + F)
1396    
1397        where:
1398            L(uₙ) = linear part evaluated via pseudo-differential operator
1399            N(uₙ) = nonlinear contribution at current time step
1400            F     = external source term at current time step
1401            Δt    = time step size
1402    
1403        Boundary conditions are applied after each update to ensure consistency.
1404    
1405        Parameters
1406            source_contribution (np.ndarray): Array representing the external source term at current time step.
1407                                              Must match the spatial dimensions of self.u_prev.
1408    
1409        Returns:
1410            np.ndarray: Updated solution array after one time step.
1411        """
1412        Lu_prev = -self._apply_psiOp(self.u_prev)
1413        rhs_nl = self._apply_nonlinear(self.u_prev, is_v=False)
1414        u_new = 2 * self.u_prev - self.u_prev2 + self.dt ** 2 * (Lu_prev + rhs_nl + source_contribution)
1415        self._apply_boundary(u_new)
1416        self.u_prev2 = self.u_prev
1417        self.u_prev = u_new
1418        self.u = u_new
1419        return u_new
1420
1421    def solve(self):
1422        """
1423        Solve the partial differential equation numerically using spectral methods.
1424        
1425        This method evolves the solution in time using a combination of:
1426        - Fourier-based linear evolution (with dealiasing)
1427        - Nonlinear term handling via pseudo-spectral evaluation
1428        - Support for pseudo-differential operators (psiOp)
1429        - Source terms and boundary conditions
1430        
1431        The solver supports:
1432        - 1D and 2D spatial domains
1433        - First and second-order time evolution
1434        - Periodic and Dirichlet boundary conditions
1435        - Time-stepping schemes: default, ETD-RK4
1436        
1437        Returns:
1438            list[np.ndarray]: A list of solution arrays at each saved time frame.
1439        
1440        Side Effects:
1441            - Updates self.frames: stores solution snapshots
1442            - Updates self.energy_history: records total energy if enabled
1443            
1444        Algorithm Overview:
1445            For each time step:
1446                1. Evaluate source contributions (if any)
1447                2. Apply time evolution:
1448                    - Order 1:
1449                        - With psiOp: uses step_order1_with_psi
1450                        - With ETD-RK4: exponential time differencing
1451                        - Default: linear + nonlinear update
1452                    - Order 2:
1453                        - With psiOp: uses step_order2_with_psi
1454                        - With ETD-RK4: second-order exponential scheme
1455                        - Default: second-order leapfrog-style update
1456                3. Enforce boundary conditions
1457                4. Save solution snapshot periodically
1458                5. Record energy (for second-order systems without psiOp)
1459        """
1460        print('\n*******************')
1461        print('* Solving the PDE *')
1462        print('*******************\n')
1463        save_interval = max(1, self.Nt // self.n_frames)
1464        self.energy_history = []
1465        for step in range(self.Nt):
1466            if hasattr(self, 'source_terms') and self.source_terms:
1467                source_contribution = np.zeros_like(self.X, dtype=np.float64)
1468                for term in self.source_terms:
1469                    try:
1470                        if self.dim == 1:
1471                            source_func = lambdify((self.t, self.x), term, 'numpy')
1472                            source_contribution += source_func(step * self.dt, self.X)
1473                        elif self.dim == 2:
1474                            source_func = lambdify((self.t, self.x, self.y), term, 'numpy')
1475                            source_contribution += source_func(step * self.dt, self.X, self.Y)
1476                    except Exception as e:
1477                        print(f'Error evaluating source term {term}: {e}')
1478            else:
1479                source_contribution = 0
1480
1481            if self.temporal_order == 1:
1482                if self.has_psi:
1483                    u_new = self._step_order1_with_psi(source_contribution)
1484                elif hasattr(self, 'time_scheme') and self.time_scheme == 'ETD-RK4':
1485                    u_new = self._step_ETD_RK4(self.u_prev)
1486                else:
1487                    u_hat = self.fft(self.u_prev)
1488                    u_hat *= self.exp_L
1489                    u_hat *= self.dealiasing_mask
1490                    u_lin = self.ifft(u_hat)
1491                    u_nl = self._apply_nonlinear(u_lin)
1492                    u_new = u_lin + u_nl + source_contribution
1493                self._apply_boundary(u_new)
1494                self.u_prev = u_new
1495
1496            elif self.temporal_order == 2:
1497                if self.has_psi:
1498                    u_new = self._step_order2_with_psi(source_contribution)
1499                else:
1500                    if hasattr(self, 'time_scheme') and self.time_scheme == 'ETD-RK4':
1501                        u_new, v_new = self._step_ETD_RK4_order2(self.u_prev, self.v_prev)
1502                    else:
1503                        u_hat = self.fft(self.u_prev)
1504                        v_hat = self.fft(self.v_prev)
1505                        u_new_hat = self.cos_omega_dt * u_hat + self.sin_omega_dt * self.inv_omega * v_hat
1506                        v_new_hat = -self.omega_val * self.sin_omega_dt * u_hat + self.cos_omega_dt * v_hat
1507                        u_new = self.ifft(u_new_hat)
1508                        v_new = self.ifft(v_new_hat)
1509                        u_nl = self._apply_nonlinear(self.u_prev, is_v=False)
1510                        v_nl = self._apply_nonlinear(self.v_prev, is_v=True)
1511                        u_new += (u_nl + source_contribution) * self.dt ** 2 / 2
1512                        v_new += (u_nl + source_contribution) * self.dt
1513                    self._apply_boundary(u_new)
1514                    self._apply_boundary(v_new)
1515                    self.u_prev = u_new
1516                    self.v_prev = v_new
1517
1518            if step % save_interval == 0:
1519                self.frames.append(self.u_prev.copy())
1520
1521            if self.temporal_order == 2 and (not self.has_psi):
1522                E = self._compute_energy()
1523                self.energy_history.append(E)
1524
1525        return self.frames  
1526                
1527    def solve_stationary_psiOp(self, order=3):
1528        """
1529        Solve stationary pseudo-differential equations of the form P[u] = f(x) or P[u] = f(x,y) using asymptotic inversion.
1530    
1531        This method computes the solution to a stationary (time-independent) pseudo-differential equation
1532        where the operator P is defined via symbolic expressions (psiOp). It constructs an asymptotic right inverse R 
1533        such that P∘R ≈ Id, then applies it to the source term f using either direct Fourier multiplication 
1534        (when the symbol is spatially independent) or Kohn–Nirenberg quantization (when spatial dependence is present).
1535    
1536        The inversion is based on the principal symbol of the operator and its asymptotic expansion up to the given order.
1537        Ellipticity of the symbol is checked numerically before inversion to ensure well-posedness.
1538    
1539        Parameters
1540        ----------
1541        order : int, default=3
1542            Order of the asymptotic expansion used to construct the right inverse of the pseudo-differential operator.
1543        method : str, optional
1544            Inversion strategy:
1545            - 'diagonal' (default): Fast approximate inversion using diagonal operators in frequency space.
1546            - 'full'                : Pointwise exact inversion (slower but more accurate).
1547    
1548        Returns
1549        -------
1550        ndarray
1551            The computed solution u(x) in 1D or u(x, y) in 2D as a NumPy array over the spatial grid.
1552    
1553        Raises
1554        ------
1555        ValueError
1556            If no pseudo-differential operator (psiOp) is defined.
1557            If linear or nonlinear terms other than psiOp are present.
1558            If the symbol is not elliptic on the grid.
1559            If no source term is provided for the right-hand side.
1560    
1561        Notes
1562        -----
1563        - The method assumes the problem is fully stationary: time derivatives must be absent.
1564        - Requires the equation to be purely pseudo-differential (no Op, Derivative, or nonlinear terms).
1565        - Symbol evaluation and inversion are dimension-aware (supports both 1D and 2D problems).
1566        - Supports optimization paths when the symbol does not depend on spatial variables.
1567    
1568        See Also
1569        --------
1570        right_inverse_asymptotic : Constructs the asymptotic inverse of the pseudo-differential operator.
1571        kohn_nirenberg           : Numerical implementation of general pseudo-differential operators.
1572        is_elliptic_numerically  : Verifies numerical ellipticity of the symbol.
1573        """
1574
1575        print("\n*******************************")
1576        print("* Solving the stationnary PDE *")
1577        print("*******************************\n")
1578        print("boundary condition: ",self.boundary_condition)
1579        
1580
1581        if not self.has_psi:
1582            raise ValueError("Only supports problems with psiOp.")
1583    
1584        if self.linear_terms or self.nonlinear_terms:
1585            raise ValueError("Stationary psiOp problems must be linear and purely pseudo-differential.")
1586
1587        if self.boundary_condition not in ('periodic', 'dirichlet'):
1588            raise ValueError(
1589                "For stationary PDEs, boundary conditions must be explicitly defined. "
1590                "Supported types are 'periodic' and 'dirichlet'."
1591            )    
1592            
1593        if self.dim == 1:
1594            x = self.x
1595            xi = symbols('xi', real=True)
1596            spatial_vars = (x,)
1597            freq_vars = (xi,)
1598            X, KX = self.X, self.KX
1599        elif self.dim == 2:
1600            x, y = self.x, self.y
1601            xi, eta = symbols('xi eta', real=True)
1602            spatial_vars = (x, y)
1603            freq_vars = (xi, eta)
1604            X, Y, KX, KY = self.X, self.Y, self.KX, self.KY
1605        else:
1606            raise ValueError("Unsupported spatial dimension.")
1607    
1608        total_symbol = sum(coeff * psi.expr for coeff, psi in self.psi_ops)
1609        psi_total = PseudoDifferentialOperator(total_symbol, spatial_vars, mode='symbol')
1610    
1611        # Check ellipticity
1612        if self.dim == 1:
1613            is_elliptic = psi_total.is_elliptic_numerically(X, KX)
1614        else:
1615            is_elliptic = psi_total.is_elliptic_numerically((X[:, 0], Y[0, :]), (KX[:, 0], KY[0, :]))
1616        if not is_elliptic:
1617            raise ValueError("❌ The pseudo-differential symbol is not numerically elliptic on the grid.")
1618        print("✅ Elliptic pseudo-differential symbol: inversion allowed.")
1619    
1620        R_symbol = psi_total.right_inverse_asymptotic(order=order)
1621        print('Right inverse asymptotic symbol:')
1622        pprint(R_symbol, num_columns=NUM_COLS)
1623        
1624        # ========================================================================
1625        # FIX: Always lambdify with all variables for consistency
1626        # ========================================================================
1627        if self.dim == 1:
1628            # Always include both x and xi in the signature
1629            R_func = lambdify((x, xi), R_symbol, modules='numpy')
1630        elif self.dim == 2:
1631            # Always include all four variables
1632            R_func = lambdify((x, y, xi, eta), R_symbol, modules='numpy')
1633        
1634        # Prepare right-hand side
1635        if self.source_terms:
1636            f_expr = sum(self.source_terms)
1637            used_vars = [v for v in spatial_vars if f_expr.has(v)]
1638            f_func = lambdify(used_vars, -f_expr, modules='numpy')
1639            if self.dim == 1:
1640                rhs = f_func(self.x_grid) if used_vars else np.zeros_like(self.x_grid)
1641            else:
1642                rhs = f_func(self.X, self.Y) if used_vars else np.zeros_like(self.X)
1643        elif self.initial_condition:
1644            raise ValueError('Initial condition should be None for stationnary equation.')
1645        else:
1646            raise ValueError('No source term provided to construct the right-hand side.')
1647        
1648        f_hat = self.fft(rhs)
1649        
1650        # ========================================================================
1651        # Application of the inverse operator
1652        # ========================================================================
1653        if self.boundary_condition == 'periodic':
1654            if self.dim == 1:
1655                # Check if optimization is possible
1656                if not R_symbol.has(x):
1657                    print('⚡ Optimization: symbol independent of x – direct product in Fourier.')
1658                    # Create wrapper that ignores x
1659                    def _R_func_optimized(kx_val):
1660                        return R_func(0.0, kx_val)  # x=0 since it doesn't matter
1661                    
1662                    R_vals = _R_func_optimized(self.KX)
1663                    u_hat = R_vals * f_hat
1664                    u = self.ifft(u_hat)
1665                else:
1666                    print('⚙️ 1D Kohn-Nirenberg Quantification')
1667                    from psiop import kohn_nirenberg_fft
1668                    u = kohn_nirenberg_fft(
1669                        u_vals=rhs,
1670                        symbol_func=R_func,  # Now has correct signature (x, xi)
1671                        x_grid=self.x_grid,
1672                        kx=self.kx,
1673                        fft_func=self.fft,
1674                        ifft_func=self.ifft,
1675                        dim=1
1676                    )
1677                    
1678            elif self.dim == 2:
1679                if not R_symbol.has(x) and not R_symbol.has(y):
1680                    print('⚡ Optimization: Symbol independent of x and y – direct product in 2D Fourier.')
1681                    # Create wrapper that ignores x, y
1682                    def _R_func_optimized(kx_val, ky_val):
1683                        return R_func(0.0, 0.0, kx_val, ky_val)
1684                    
1685                    R_vals = _R_func_optimized(self.KX, self.KY)
1686                    u_hat = R_vals * f_hat
1687                    u = self.ifft(u_hat)
1688                else:
1689                    print('⚙️ 2D Kohn-Nirenberg Quantification')
1690                    from psiop import kohn_nirenberg_fft
1691                    u = kohn_nirenberg_fft(
1692                        u_vals=rhs,
1693                        symbol_func=R_func,  # Now has correct signature (x, y, xi, eta)
1694                        x_grid=self.x_grid,
1695                        kx=self.kx,
1696                        fft_func=self.fft,
1697                        ifft_func=self.ifft,
1698                        dim=2,
1699                        y_grid=self.y_grid,
1700                        ky=self.ky
1701                    )
1702            self.u = u
1703            return u
1704            
1705        elif self.boundary_condition == 'dirichlet':
1706            from psiop import kohn_nirenberg_nonperiodic
1707            
1708            if self.dim == 1:
1709                u = kohn_nirenberg_nonperiodic(
1710                    u_vals=rhs,
1711                    x_grid=self.x_grid,
1712                    xi_grid=self.kx,
1713                    symbol_func=R_func  # Now has correct signature (x, xi)
1714                )
1715            elif self.dim == 2:
1716                u = kohn_nirenberg_nonperiodic(
1717                    u_vals=rhs,
1718                    x_grid=(self.x_grid, self.y_grid),
1719                    xi_grid=(self.kx, self.ky),
1720                    symbol_func=R_func  # Now has correct signature (x, y, xi, eta)
1721                )
1722            self.u = u
1723            return u
1724        
1725        else:
1726            raise ValueError(f"Invalid boundary condition '{self.boundary_condition}'. Supported types are 'periodic' and 'dirichlet'.")
1727        
1728    def _step_ETD_RK4(self, u):
1729        """
1730        Perform one Exponential Time Differencing Runge-Kutta of 4th order (ETD-RK4) time step 
1731        for first-order in time PDEs of the form:
1732        
1733            ∂ₜu = L u + N(u)
1734        
1735        where L is a linear operator (possibly nonlocal or pseudo-differential), and N is a 
1736        nonlinear term treated via pseudo-spectral methods. This method evaluates the 
1737        exponential integrator up to fourth-order accuracy in time.
1738    
1739        The ETD-RK4 scheme uses four stages to approximate the integral of the variation-of-constants formula:
1740        
1741            uⁿ⁺¹ = e^(L Δt) uⁿ + Δt ∫₀¹ e^(L Δt (1 - τ)) φ(N(u(τ))) dτ
1742        
1743        where φ denotes the nonlinear contributions evaluated at intermediate stages.
1744    
1745        Parameters
1746            u (np.ndarray): Current solution in real space (physical grid values).
1747    
1748        Returns:
1749            np.ndarray: Updated solution in real space after one ETD-RK4 time step.
1750    
1751        Notes:
1752        - The linear part L is diagonal in Fourier space and precomputed as self.L(k).
1753        - Nonlinear terms are evaluated in physical space and transformed via FFT.
1754        - The functions φ₁(z) and φ₂(z) are entire functions arising from the ETD scheme:
1755          
1756              φ₁(z) = (eᶻ - 1)/z   if z ≠ 0
1757                     = 1            if z = 0
1758    
1759              φ₂(z) = (eᶻ - 1 - z)/z²   if z ≠ 0
1760                     = ½              if z = 0
1761    
1762        - This implementation assumes periodic boundary conditions and uses spectral differentiation via FFT.
1763        - See Hochbruck & Ostermann (2010) for theoretical background on exponential integrators.
1764    
1765        See Also:
1766            step_ETD_RK4_order2 : For second-order in time equations.
1767            psiOp_apply           : For applying pseudo-differential operators.
1768            apply_nonlinear      : For handling nonlinear terms in the PDE.
1769        """
1770        dt = self.dt
1771        L_fft = self.L(self.KX) if self.dim == 1 else self.L(self.KX, self.KY)
1772    
1773        E  = np.exp(dt * L_fft)
1774        E2 = np.exp(dt * L_fft / 2)
1775    
1776        def phi1(z):
1777            return np.where(np.abs(z) > 1e-12, (np.exp(z) - 1) / z, 1.0)
1778    
1779        def phi2(z):
1780            return np.where(np.abs(z) > 1e-12, (np.exp(z) - 1 - z) / z**2, 0.5)
1781    
1782        phi1_dtL = phi1(dt * L_fft)
1783        phi2_dtL = phi2(dt * L_fft)
1784    
1785        fft = self.fft
1786        ifft = self.ifft
1787    
1788        u_hat = fft(u)
1789        N1 = fft(self._apply_nonlinear(u))
1790    
1791        a = ifft(E2 * (u_hat + 0.5 * dt * N1 * phi1_dtL))
1792        N2 = fft(self._apply_nonlinear(a))
1793    
1794        b = ifft(E2 * (u_hat + 0.5 * dt * N2 * phi1_dtL))
1795        N3 = fft(self._apply_nonlinear(b))
1796    
1797        c = ifft(E * (u_hat + dt * N3 * phi1_dtL))
1798        N4 = fft(self._apply_nonlinear(c))
1799    
1800        u_new_hat = E * u_hat + dt * (
1801            N1 * phi1_dtL + 2 * (N2 + N3) * phi2_dtL + N4 * phi1_dtL
1802        ) / 6
1803    
1804        return ifft(u_new_hat)
1805
1806    def _step_ETD_RK4_order2(self, u, v):
1807        """
1808        Perform one time step of the Exponential Time Differencing Runge-Kutta 4th-order (ETD-RK4) scheme for second-order PDEs.
1809    
1810        This method evolves the solution u and its time derivative v forward in time by one step using the ETD-RK4 integrator. 
1811        It is designed for systems of the form:
1812        
1813            ∂ₜ²u = L u + N(u)
1814            
1815        where L is a linear operator and N is a nonlinear term computed via self._apply_nonlinear.
1816        
1817        The exponential integrator handles the linear part exactly in Fourier space, while the nonlinear terms are integrated 
1818        using a fourth-order Runge-Kutta-like approach. This ensures high accuracy and stability for stiff systems.
1819    
1820        Parameters:
1821            u (np.ndarray): Current solution array in real space.
1822            v (np.ndarray): Current time derivative of the solution (∂ₜu) in real space.
1823    
1824        Returns:
1825            tuple: (u_new, v_new), updated solution and its time derivative after one time step.
1826    
1827        Notes:
1828            - Assumes periodic boundary conditions and uses FFT-based spectral methods.
1829            - Handles both 1D and 2D problems seamlessly.
1830            - Uses phi functions to compute exponential integrators efficiently.
1831            - Suitable for wave equations and other second-order evolution equations with stiffness.
1832        """
1833        dt = self.dt
1834    
1835        L_fft = self.L(self.KX) if self.dim == 1 else self.L(self.KX, self.KY)
1836        fft = self.fft
1837        ifft = self.ifft
1838    
1839        def rhs(u_val):
1840            return ifft(L_fft * fft(u_val)) + self._apply_nonlinear(u_val, is_v=False)
1841    
1842        # Stage A
1843        A = rhs(u)
1844        ua = u + 0.5 * dt * v
1845        va = v + 0.5 * dt * A
1846    
1847        # Stage B
1848        B = rhs(ua)
1849        ub = u + 0.5 * dt * va
1850        vb = v + 0.5 * dt * B
1851    
1852        # Stage C
1853        C = rhs(ub)
1854        uc = u + dt * vb
1855    
1856        # Stage D
1857        D = rhs(uc)
1858    
1859        # Final update
1860        u_new = u + dt * v + (dt**2 / 6.0) * (A + 2*B + 2*C + D)
1861        v_new = v + (dt / 6.0) * (A + 2*B + 2*C + D)
1862    
1863        return u_new, v_new
1864
1865    def _check_cfl_condition(self):
1866        """
1867        Check the CFL (Courant–Friedrichs–Lewymann) condition based on group velocity 
1868        for second-order time-dependent PDEs.
1869    
1870        This method verifies whether the chosen time step dt satisfies the numerical stability 
1871        condition derived from the maximum wave propagation speed in the system. It supports both 
1872        1D and 2D problems, with or without a symbolic dispersion relation ω(k).
1873    
1874        The CFL condition ensures that information does not propagate further than one grid cell 
1875        per time step. A safety factor of 0.5 is applied by default to ensure robustness.
1876    
1877        Notes:
1878        
1879        - In 1D, the group velocity v₉(k) = dω/dk is used to compute the maximum wave speed.
1880        - In 2D, the x- and y-directional group velocities are evaluated independently.
1881        - If no dispersion relation is available, the imaginary part of the linear operator L(k) 
1882          is used as an approximation for wave speed.
1883    
1884        Raises:
1885        -------
1886        NotImplementedError: 
1887            If the spatial dimension is not 1D or 2D.
1888    
1889        Prints:
1890        -------
1891        Warning message if the current time step dt exceeds the CFL-stable limit.
1892        """
1893        print("\n*****************")
1894        print("* CFL condition *")
1895        print("*****************\n")
1896
1897        cfl_factor = 0.5  # Safety factor
1898        
1899        if self.dim == 1:
1900            if self.temporal_order == 2 and hasattr(self, 'omega'):
1901                k_vals = self.kx
1902                omega_vals = np.real(self.omega(k_vals))
1903                with np.errstate(divide='ignore', invalid='ignore'):
1904                    v_group = np.gradient(omega_vals, k_vals)
1905                max_speed = np.max(np.abs(v_group))
1906            else:
1907                max_speed = np.max(np.abs(np.imag(self.L(self.kx))))
1908            
1909            dx = self.Lx / self.Nx
1910            cfl_limit = cfl_factor * dx / max_speed if max_speed != 0 else np.inf
1911            
1912            if self.dt > cfl_limit:
1913                print(f"CFL condition violated: dt = {self.dt}, max allowed dt = {cfl_limit}")
1914    
1915        elif self.dim == 2:
1916            if self.temporal_order == 2 and hasattr(self, 'omega'):
1917                k_vals = self.kx
1918                omega_x = np.real(self.omega(k_vals, 0))
1919                omega_y = np.real(self.omega(0, k_vals))
1920                with np.errstate(divide='ignore', invalid='ignore'):
1921                    v_group_x = np.gradient(omega_x, k_vals)
1922                    v_group_y = np.gradient(omega_y, k_vals)
1923                max_speed_x = np.max(np.abs(v_group_x))
1924                max_speed_y = np.max(np.abs(v_group_y))
1925            else:
1926                max_speed_x = np.max(np.abs(np.imag(self.L(self.kx, 0))))
1927                max_speed_y = np.max(np.abs(np.imag(self.L(0, self.ky))))
1928            
1929            dx = self.Lx / self.Nx
1930            dy = self.Ly / self.Ny
1931            cfl_limit = cfl_factor / (max_speed_x / dx + max_speed_y / dy) if (max_speed_x + max_speed_y) != 0 else np.inf
1932            
1933            if self.dt > cfl_limit:
1934                print(f"CFL condition violated: dt = {self.dt}, max allowed dt = {cfl_limit}")
1935    
1936        else:
1937            raise NotImplementedError("Only 1D and 2D problems are supported.")
1938
1939    def _check_symbol_conditions(self, k_range=None, verbose=True):
1940        """
1941        Check strict analytic conditions on the linear symbol self.L_symbolic:
1942            This method evaluates three key properties of the Fourier multiplier 
1943            symbol a(k) = self.L(k), which are crucial for well-posedness, stability,
1944            and numerical efficiency. The checks apply to both 1D and 2D cases.
1945        
1946        Conditions checked:
1947        ------------------
1948        1. **Stability condition**: Re(a(k)) ≤ 0 for all k ≠ 0
1949           Ensures that the system does not exhibit exponential growth in time.
1950    
1951        2. **Dissipation condition**: Re(a(k)) ≤ -δ |k|² for large |k|
1952           Ensures sufficient damping at high frequencies to avoid oscillatory instability.
1953    
1954        3. **Growth condition**: |a(k)| ≤ C (1 + |k|)^m with m ≤ 4
1955           Ensures that the symbol does not grow too rapidly with frequency, 
1956           which would otherwise cause numerical instability or unphysical amplification.
1957    
1958        Parameters
1959        ----------
1960        k_range : tuple or None, optional
1961            Specifies the range of frequencies to test in the form (k_min, k_max, N).
1962            If None, defaults are used: [-10, 10] with 500 points in 1D, or [-10, 10] 
1963            with 100 points per axis in 2D.
1964    
1965        verbose : bool, default=True
1966            If True, prints detailed results of each condition check.
1967    
1968        Returns:
1969        --------
1970        None
1971            Output is printed directly to the console for interpretability.
1972    
1973        Notes:
1974        ------
1975        - In 2D, the radial frequency |k| = √(kx² + ky²) is used for comparisons.
1976        - The dissipation threshold assumes δ = 0.01 and p = 2 by default.
1977        - The growth ratio is compared against |k|⁴; values above 100 indicate rapid growth.
1978        - This function is typically called during solver setup or analysis phase.
1979    
1980        See Also:
1981        ---------
1982        analyze_wave_propagation : For further symbolic and numerical analysis of dispersion.
1983        plot_symbol : Visualizes the symbol's behavior over the frequency domain.
1984        """
1985        print("\n********************")
1986        print("* Symbol condition *")
1987        print("********************\n")
1988
1989    
1990        if self.dim == 1:    
1991            if k_range is None:
1992                k_vals = np.linspace(-10, 10, 500)
1993            else:
1994                k_min, k_max, N = k_range
1995                k_vals = np.linspace(k_min, k_max, N)
1996    
1997            L_vals = self.L(k_vals)
1998            k_abs = np.abs(k_vals)
1999    
2000        elif self.dim == 2:
2001            if k_range is None:
2002                k_vals = np.linspace(-10, 10, 100)
2003            else:
2004                k_min, k_max, N = k_range
2005                k_vals = np.linspace(k_min, k_max, N)
2006    
2007            KX, KY = np.meshgrid(k_vals, k_vals)
2008            L_vals = self.L(KX, KY)
2009            k_abs = np.sqrt(KX**2 + KY**2)
2010    
2011        else:
2012            raise ValueError("Only 1D and 2D dimensions are supported.")
2013
2014    
2015        re_vals = np.real(L_vals)
2016        abs_vals = np.abs(L_vals)
2017    
2018        # === Condition 1: Stability
2019        if np.any(re_vals > 1e-12):
2020            max_pos = np.max(re_vals)
2021            if verbose:
2022                print(f"❌ Stability violated: max Re(a(k)) = {max_pos}")
2023            print("Unstable symbol: Re(a(k)) > 0")
2024        elif verbose:
2025            print("✅ Spectral stability satisfied: Re(a(k)) ≤ 0")
2026    
2027        # === Condition 2: Dissipation
2028        mask = k_abs > 2
2029        if np.any(mask):
2030            re_decay = re_vals[mask]
2031            expected_decay = -0.01 * k_abs[mask]**2
2032            if np.any(re_decay > expected_decay + 1e-6):
2033                if verbose:
2034                    print("⚠️ Insufficient high-frequency dissipation")
2035            else:
2036                if verbose:
2037                    print("✅ Proper high-frequency dissipation")
2038    
2039        # === Condition 3: Growth
2040        growth_ratio = abs_vals / (1 + k_abs)**4
2041        if np.max(growth_ratio) > 100:
2042            if verbose:
2043                print("⚠️ Symbol grows rapidly: |a(k)| ≳ |k|^4")
2044        else:
2045            if verbose:
2046                print("✅ Reasonable spectral growth")
2047    
2048        if verbose:
2049            print("✔ Symbol analysis completed.")
2050
2051    def _analyze_wave_propagation(self):
2052        """
2053        Perform a detailed analysis of wave propagation characteristics based on the dispersion relation ω(k).
2054    
2055        This method visualizes key wave properties in both 1D and 2D settings:
2056        
2057        - Dispersion relation: ω(k)
2058        - Phase velocity: v_p(k) = ω(k)/|k|
2059        - Group velocity: v_g(k) = ∇ₖ ω(k)
2060        - Anisotropy in 2D (via magnitude of group velocity)
2061    
2062        The symbolic dispersion relation 'omega_symbolic' must be defined beforehand.
2063        This is typically available only for second-order-in-time equations.
2064    
2065        In 1D:
2066            Plots ω(k), v_p(k), and v_g(k) over a range of k values.
2067    
2068        In 2D:
2069            Displays heatmaps of ω(kx, ky), v_p(kx, ky), and |v_g(kx, ky)| over a 2D wavenumber grid.
2070    
2071        Raises:
2072            AttributeError: If 'omega_symbolic' is not defined, the method exits gracefully with a message.
2073    
2074        Side Effects:
2075            Generates and displays matplotlib plots.
2076        """
2077        print("\n*****************************")
2078        print("* Wave propagation analysis *")
2079        print("*****************************\n")
2080        if not hasattr(self, 'omega_symbolic'):
2081            print("❌ omega_symbolic not defined. Only available for 2nd order in time.")
2082            return
2083    
2084        if self.dim == 1:
2085            k = self.k_symbols[0]
2086            omega_func = lambdify(k, self.omega_symbolic, 'numpy')
2087    
2088            k_vals = np.linspace(-10, 10, 1000)
2089            omega_vals = omega_func(k_vals)
2090    
2091            with np.errstate(divide='ignore', invalid='ignore'):
2092                v_phase = np.where(k_vals != 0, omega_vals / k_vals, 0.0)
2093    
2094            dk = k_vals[1] - k_vals[0]
2095            v_group = np.gradient(omega_vals, dk)
2096    
2097            plt.figure(figsize=(10, 6))
2098            plt.plot(k_vals, omega_vals, label=r'$\omega(k)$')
2099            plt.plot(k_vals, v_phase, label=r'$v_p(k)$')
2100            plt.plot(k_vals, v_group, label=r'$v_g(k)$')
2101            plt.title("1D Wave Propagation Analysis")
2102            plt.xlabel("k")
2103            plt.grid()
2104            plt.legend()
2105            plt.tight_layout()
2106            plt.show()
2107    
2108        elif self.dim == 2:
2109            kx, ky = self.k_symbols
2110            omega_func = lambdify((kx, ky), self.omega_symbolic, 'numpy')
2111    
2112            k_vals = np.linspace(-10, 10, 200)
2113            KX, KY = np.meshgrid(k_vals, k_vals)
2114            K_mag = np.sqrt(KX**2 + KY**2)
2115            K_mag[K_mag == 0] = 1e-8  # Avoid division by 0
2116    
2117            omega_vals = omega_func(KX, KY)
2118            v_phase = np.real(omega_vals) / K_mag
2119    
2120            dk = k_vals[1] - k_vals[0]
2121            domega_dx = np.gradient(omega_vals, dk, axis=0)
2122            domega_dy = np.gradient(omega_vals, dk, axis=1)
2123            v_group_norm = np.sqrt(np.abs(domega_dx)**2 + np.abs(domega_dy)**2)
2124    
2125            fig, axs = plt.subplots(1, 3, figsize=(18, 5))
2126            im0 = axs[0].imshow(np.real(omega_vals), extent=[-10, 10, -10, 10],
2127                                origin='lower', cmap='viridis')
2128            axs[0].set_title(r'$\omega(k_x, k_y)$')
2129            plt.colorbar(im0, ax=axs[0])
2130    
2131            im1 = axs[1].imshow(v_phase, extent=[-10, 10, -10, 10],
2132                                origin='lower', cmap='plasma')
2133            axs[1].set_title(r'$v_p(k_x, k_y)$')
2134            plt.colorbar(im1, ax=axs[1])
2135    
2136            im2 = axs[2].imshow(v_group_norm, extent=[-10, 10, -10, 10],
2137                                origin='lower', cmap='inferno')
2138            axs[2].set_title(r'$|v_g(k_x, k_y)|$')
2139            plt.colorbar(im2, ax=axs[2])
2140    
2141            for ax in axs:
2142                ax.set_xlabel(r'$k_x$')
2143                ax.set_ylabel(r'$k_y$')
2144                ax.set_aspect('equal')
2145    
2146            plt.tight_layout()
2147            plt.show()
2148    
2149        else:
2150            print("❌ Only 1D and 2D wave analysis supported.")
2151        
2152    def _plot_symbol(self, component="abs", k_range=None, cmap="viridis"):
2153        """
2154        Visualize the spectral symbol L(k) or L(kx, ky) in 1D or 2D.
2155    
2156        This method plots the linear operator's symbolic Fourier representation 
2157        either as a function of a single wavenumber k (1D), or two wavenumbers 
2158        kx and ky (2D). The user can choose to display the real part, imaginary part, 
2159        or absolute value of the symbol.
2160    
2161        Parameters
2162        ----------
2163        component : str {'abs', 're', 'im'}
2164            Component of the symbol to visualize:
2165            
2166                - 'abs' : absolute value |a(k)|
2167                - 're'  : real part Re[a(k)]
2168                - 'im'  : imaginary part Im[a(k)]
2169                
2170        k_range : tuple (kmin, kmax, N), optional
2171            Wavenumber range for evaluation:
2172            
2173                - kmin: minimum wavenumber
2174                - kmax: maximum wavenumber
2175                - N: number of sampling points
2176                
2177            If None, defaults to [-10, 10] with high resolution.
2178        cmap : str, optional
2179            Colormap used for 2D surface plots. Default is 'viridis'.
2180    
2181        Raises
2182        ------
2183            ValueError: If the spatial dimension is not 1D or 2D.
2184    
2185        Notes:
2186            - In 1D, the symbol is plotted using a standard 2D line plot.
2187            - In 2D, a 3D surface plot is generated with color-mapped height.
2188            - Symbol evaluation uses self.L(k), which must be defined and callable.
2189        """
2190        print("\n*******************")
2191        print("* Symbol plotting *")
2192        print("*******************\n")
2193        
2194        assert component in ("abs", "re", "im"), "component must be 'abs', 're' or 'im'"
2195        
2196    
2197        if self.dim == 1:
2198            if k_range is None:
2199                k_vals = np.linspace(-10, 10, 1000)
2200            else:
2201                kmin, kmax, N = k_range
2202                k_vals = np.linspace(kmin, kmax, N)
2203            L_vals = self.L(k_vals)
2204    
2205            if component == "re":
2206                vals = np.real(L_vals)
2207                label = "Re[a(k)]"
2208            elif component == "im":
2209                vals = np.imag(L_vals)
2210                label = "Im[a(k)]"
2211            else:
2212                vals = np.abs(L_vals)
2213                label = "|a(k)|"
2214    
2215            plt.plot(k_vals, vals)
2216            plt.xlabel("k")
2217            plt.ylabel(label)
2218            plt.title(f"Spectral symbol: {label}")
2219            plt.grid(True)
2220            plt.show()
2221    
2222        elif self.dim == 2:
2223            if k_range is None:
2224                k_vals = np.linspace(-10, 10, 300)
2225            else:
2226                kmin, kmax, N = k_range
2227                k_vals = np.linspace(kmin, kmax, N)
2228    
2229            KX, KY = np.meshgrid(k_vals, k_vals)
2230            L_vals = self.L(KX, KY)
2231    
2232            if component == "re":
2233                Z = np.real(L_vals)
2234                title = "Re[a(kx, ky)]"
2235            elif component == "im":
2236                Z = np.imag(L_vals)
2237                title = "Im[a(kx, ky)]"
2238            else:
2239                Z = np.abs(L_vals)
2240                title = "|a(kx, ky)|"
2241    
2242            fig = plt.figure(figsize=(8, 6))
2243            ax = fig.add_subplot(111, projection='3d')
2244        
2245            surf = ax.plot_surface(KX, KY, Z, cmap=cmap, edgecolor='none', antialiased=True)
2246            fig.colorbar(surf, ax=ax, shrink=0.6)
2247        
2248            ax.set_xlabel("kx")
2249            ax.set_ylabel("ky")
2250            ax.set_zlabel(title)
2251            ax.set_title(f"2D spectral symbol: {title}")
2252            plt.tight_layout()
2253            plt.show()
2254    
2255        else:
2256            raise ValueError("Only 1D and 2D supported.")
2257
2258    def _compute_energy(self):
2259        """
2260        Compute the total energy of the wave equation solution for second-order temporal PDEs. 
2261        The energy is defined as:
2262            E(t) = 1/2 ∫ [ (∂ₜu)² + |L¹ᐟ²u|² ] dx
2263        where L is the linear operator associated with the spatial part of the PDE,
2264        and L¹ᐟ² denotes its square root in Fourier space.
2265    
2266        This method supports both 1D and 2D problems and is only meaningful when 
2267        self.temporal_order == 2 (second-order time derivative).
2268    
2269        Returns
2270        -------
2271        float or None: 
2272            Total energy at current time step. Returns None if the temporal order is not 2 or if no valid velocity data (v_prev) is available.
2273    
2274        Notes
2275        -----
2276        - Uses FFT-based spectral differentiation to compute the spatial contributions.
2277        - Assumes periodic boundary conditions.
2278        - Handles both real and complex-valued solutions.
2279        """
2280        if self.temporal_order != 2 or self.v_prev is None:
2281            return None
2282    
2283        u = self.u_prev
2284        v = self.v_prev
2285    
2286        # Fourier transform of u
2287        u_hat = self.fft(u)
2288    
2289        if self.dim == 1:
2290            # 1D case
2291            L_vals = self.L(self.KX)
2292            sqrt_L = np.sqrt(np.abs(L_vals))
2293            Lu_hat = sqrt_L * u_hat  # Apply sqrt(|L(k)|) in Fourier space
2294            Lu = self.ifft(Lu_hat)
2295    
2296            dx = self.Lx / self.Nx
2297            energy_density = 0.5 * (np.abs(v)**2 + np.abs(Lu)**2)
2298            total_energy = np.sum(energy_density) * dx
2299    
2300        elif self.dim == 2:
2301            # 2D case
2302            L_vals = self.L(self.KX, self.KY)
2303            sqrt_L = np.sqrt(np.abs(L_vals))
2304            Lu_hat = sqrt_L * u_hat
2305            Lu = self.ifft(Lu_hat)
2306    
2307            dx = self.Lx / self.Nx
2308            dy = self.Ly / self.Ny
2309            energy_density = 0.5 * (np.abs(v)**2 + np.abs(Lu)**2)
2310            total_energy = np.sum(energy_density) * dx * dy
2311    
2312        else:
2313            raise ValueError("Unsupported dimension for u.")
2314    
2315        return total_energy
2316
2317    def plot_energy(self, log=False):
2318        """
2319        Plot the time evolution of the total energy for wave equations. 
2320        Visualizes the energy computed during simulation for both 1D and 2D cases. 
2321        Requires temporal_order=2 and prior execution of compute_energy() during solve().
2322        
2323        Parameters:
2324            log : bool
2325                If True, displays energy on a logarithmic scale to highlight exponential decay/growth.
2326        
2327        Notes:
2328            - Energy is defined as E(t) = 1/2 ∫ [ (∂ₜu)² + |L¹⸍²u|² ] dx
2329            - Only available if energy monitoring was activated in solve()
2330            - Automatically skips plotting if no energy data is available
2331        
2332        Displays:
2333            - Time vs. Total Energy plot with grid and legend
2334            - Appropriate axis labels and dimensional context (1D/2D)
2335            - Logarithmic or linear scaling based on input parameter
2336        """
2337        if not hasattr(self, 'energy_history') or not self.energy_history:
2338            print("No energy data recorded. Call compute_energy() within solve().")
2339            return
2340    
2341        # Time vector for plotting
2342        t = np.linspace(0, self.Lt, len(self.energy_history))
2343    
2344        # Create the figure
2345        plt.figure(figsize=(6, 4))
2346        if log:
2347            plt.semilogy(t, self.energy_history, label="Energy (log scale)")
2348        else:
2349            plt.plot(t, self.energy_history, label="Energy")
2350    
2351        # Axis labels and title
2352        plt.xlabel("Time")
2353        plt.ylabel("Total energy")
2354        plt.title("Energy evolution ({}D)".format(self.dim))
2355    
2356        # Display options
2357        plt.grid(True)
2358        plt.legend()
2359        plt.tight_layout()
2360        plt.show()
2361
2362    def show_stationary_solution(self, u=None, component='abs', cmap='viridis'):
2363        """
2364        Display the stationary solution computed by solve_stationary_psiOp.
2365
2366        This method visualizes the solution of a pseudo-differential equation 
2367        solved in stationary mode. It supports both 1D and 2D spatial domains, 
2368        with options to display different components of the solution (real, 
2369        imaginary, absolute value, or phase).
2370
2371        Parameters
2372        ----------
2373        u : ndarray, optional
2374            Precomputed solution array. If None, calls solve_stationary_psiOp() 
2375            to compute the solution.
2376        component : str, optional {'real', 'imag', 'abs', 'angle'}
2377            Component of the complex-valued solution to display:
2378            - 'real': Real part
2379            - 'imag': Imaginary part
2380            - 'abs' : Absolute value (modulus)
2381            - 'angle' : Phase (argument)
2382        cmap : str, optional
2383            Colormap used for 2D visualization (default: 'viridis').
2384
2385        Raises
2386        ------
2387        ValueError
2388            If an invalid component is specified or if the spatial dimension 
2389            is not supported (only 1D and 2D are implemented).
2390
2391        Notes
2392        -----
2393        - In 1D, the solution is displayed using a standard line plot.
2394        - In 2D, the solution is visualized as a 3D surface plot.
2395        """
2396        def _get_component(u):
2397            if component == 'real':
2398                return np.real(u)
2399            elif component == 'imag':
2400                return np.imag(u)
2401            elif component == 'abs':
2402                return np.abs(u)
2403            elif component == 'angle':
2404                return np.angle(u)
2405            else:
2406                raise ValueError("Invalid component")
2407                
2408        if u is None:
2409            u = self.solve_stationary_psiOp()
2410
2411        if self.dim == 1:
2412            # Plot the solution in 1D
2413            plt.figure(figsize=(8, 4))
2414            plt.plot(self.x_grid, get_component(u), label=f'{component} of u')
2415            plt.xlabel('x')
2416            plt.ylabel(f'{component} of u')
2417            plt.title('Stationary solution (1D)')
2418            plt.grid(True)
2419            plt.legend()
2420            plt.tight_layout()
2421            plt.show()
2422    
2423        elif self.dim == 2:
2424            fig = plt.figure(figsize=(12, 6))
2425            ax = fig.add_subplot(111, projection='3d')
2426            ax.set_xlabel('x')
2427            ax.set_ylabel('y')
2428            ax.set_zlabel(f'{component.title()} of u')
2429            plt.title('Stationary solution (2D)')    
2430            data0 = get_component(u)
2431            ax.plot_surface(self.X, self.Y, data0, cmap='viridis')
2432            plt.tight_layout()
2433            plt.show()
2434    
2435        else:
2436            raise ValueError("Only 1D and 2D display are supported.")
2437
2438    def animate(self, component='abs', overlay='contour', mode='surface'):
2439        """
2440        Create an animated plot of the solution evolution over time.
2441    
2442        This method generates a dynamic visualization of the stored solution frames
2443        `self.frames`. It supports:
2444          - 1D line animation (unchanged),
2445          - 2D surface animation (original behavior, 'surface'),
2446          - 2D image animation using imshow (new, 'imshow') which is faster and
2447            often clearer for large grids.
2448    
2449        Parameters
2450        ----------
2451        component : str, optional, one of {'real', 'imag', 'abs', 'angle'}
2452            Which component of the complex field to visualize:
2453              - 'real'  : Re(u)
2454              - 'imag'  : Im(u)
2455              - 'abs'   : |u|
2456              - 'angle' : arg(u)
2457            Default is 'abs'.
2458    
2459        overlay : str or None, optional, one of {'contour', 'front', None}
2460            For 2D modes only. If None, no overlay is drawn.
2461              - 'contour' : draw contour lines on top (or beneath for 3D surface)
2462              - 'front'   : detect and mark wavefronts using gradient maxima
2463            Default is 'contour'.
2464    
2465        mode : str, optional, one of {'surface', 'imshow'}
2466            2D rendering mode. 'surface' keeps the original 3D surface plot.
2467            'imshow' draws a 2D raster (faster, often more readable).
2468            Default is 'surface' for backward compatibility.
2469    
2470        Returns
2471        -------
2472        FuncAnimation
2473            A Matplotlib `FuncAnimation` instance (you can display it in a notebook
2474            or save it to file).
2475    
2476        Notes
2477        -----
2478        - The method uses the same time-mapping logic as before (linear sampling of
2479          stored frames to animation frames).
2480        - For 'angle' the color scale is fixed between -π and π.
2481        - For other components, color scaling is by default dynamically adapted per
2482          frame in 'imshow' mode (this avoids extreme clipping if amplitudes vary).
2483        - Overlays are updated cleanly: previous contour/scatter artists are removed
2484          before drawing the next frame to avoid memory/visual accumulation.
2485        - Animation interval is 50 ms per frame (unchanged).
2486        """
2487        def _get_component(u):
2488            if component == 'real':
2489                return np.real(u)
2490            elif component == 'imag':
2491                return np.imag(u)
2492            elif component == 'abs':
2493                return np.abs(u)
2494            elif component == 'angle':
2495                return np.angle(u)
2496            else:
2497                raise ValueError("Invalid component: choose 'real','imag','abs' or 'angle'")
2498    
2499        print("\n*********************")
2500        print("* Solution plotting *")
2501        print("*********************\n")
2502    
2503        # === Calculate time vector of stored frames ===
2504        save_interval = max(1, self.Nt // self.n_frames)
2505        frame_times = np.arange(0, self.Lt + self.dt, save_interval * self.dt)
2506    
2507        # === Target times for animation ===
2508        target_times = np.linspace(0, self.Lt, self.n_frames // 2)
2509    
2510        # Map target times to nearest frame indices
2511        frame_indices = [np.argmin(np.abs(frame_times - t)) for t in target_times]
2512    
2513        # -------------------------
2514        # 1D case (unchanged logic)
2515        # -------------------------
2516        if self.dim == 1:
2517            fig, ax = plt.subplots()
2518            initial = get_component(self.frames[0])
2519            line, = ax.plot(self.X, np.real(initial) if np.iscomplexobj(initial) else initial)
2520            ax.set_ylim(np.min(initial), np.max(initial))
2521            ax.set_xlabel('x')
2522            ax.set_ylabel(f'{component} of u')
2523            ax.set_title('Initial condition')
2524            plt.tight_layout()
2525    
2526            def _update_1d(frame_number):
2527                frame = frame_indices[frame_number]
2528                ydata = get_component(self.frames[frame])
2529                ydata_real = np.real(ydata) if np.iscomplexobj(ydata) else ydata
2530                line.set_ydata(ydata_real)
2531                ax.set_ylim(np.min(ydata_real), np.max(ydata_real))
2532                current_time = target_times[frame_number]
2533                ax.set_title(f't = {current_time:.2f}')
2534                return (line,)
2535    
2536            ani = FuncAnimation(fig, update_1d, frames=len(target_times), interval=50)
2537            return ani
2538    
2539        # -------------------------
2540        # 2D case
2541        # -------------------------
2542        # Validate mode
2543        if mode not in ('surface', 'imshow'):
2544            raise ValueError("Invalid mode: choose 'surface' or 'imshow'")
2545    
2546        # Common data
2547        data0 = get_component(self.frames[0])
2548    
2549        if mode == 'surface':
2550            # original surface behavior, but ensure clean updates
2551            fig = plt.figure(figsize=(14, 8))
2552            ax = fig.add_subplot(111, projection='3d')
2553            ax.set_xlabel('x')
2554            ax.set_ylabel('y')
2555            ax.set_zlabel(f'{component.title()} of u')
2556            ax.zaxis.labelpad = 0
2557            ax.set_title('Initial condition')
2558    
2559            surf = ax.plot_surface(self.X, self.Y, data0, cmap='viridis')
2560            plt.tight_layout()
2561    
2562            def _update_surface(frame_number):
2563                frame = frame_indices[frame_number]
2564                current_data = get_component(self.frames[frame])
2565                z_offset = np.max(current_data) + 0.05 * (np.max(current_data) - np.min(current_data))
2566    
2567                ax.clear()
2568                surf_obj = ax.plot_surface(self.X, self.Y, current_data,
2569                                           cmap='viridis',
2570                                           vmin=(-np.pi if component == 'angle' else None),
2571                                           vmax=(np.pi if component == 'angle' else None))
2572                # overlays
2573                if overlay == 'contour':
2574                    # place contours slightly below the surface (use offset)
2575                    try:
2576                        ax.contour(self.X, self.Y, current_data, levels=10, cmap='cool', offset=z_offset)
2577                    except Exception:
2578                        # fallback: simple contour without offset if not supported
2579                        ax.contour(self.X, self.Y, current_data, levels=10, cmap='cool')
2580    
2581                elif overlay == 'front':
2582                    dx = self.x_grid[1] - self.x_grid[0]
2583                    dy = self.y_grid[1] - self.y_grid[0]
2584                    # numpy.gradient: axis0 -> y spacing, axis1 -> x spacing
2585                    du_dy, du_dx = np.gradient(current_data, dy, dx)
2586                    grad_norm = np.sqrt(du_dx**2 + du_dy**2)
2587                    local_max = (grad_norm == maximum_filter(grad_norm, size=5))
2588                    if np.max(grad_norm) > 0:
2589                        normalized = grad_norm[local_max] / np.max(grad_norm)
2590                    else:
2591                        normalized = np.zeros(np.count_nonzero(local_max))
2592                    colors = cm.plasma(normalized)
2593                    ax.scatter(self.X[local_max], self.Y[local_max],
2594                               z_offset * np.ones_like(self.X[local_max]),
2595                               color=colors, s=10, alpha=0.8)
2596    
2597                ax.set_xlabel('x')
2598                ax.set_ylabel('y')
2599                ax.set_zlabel(f'{component.title()} of u')
2600                current_time = target_times[frame_number]
2601                ax.set_title(f'Solution at t = {current_time:.2f}')
2602                return (surf_obj,)
2603    
2604            ani = FuncAnimation(fig, update_surface, frames=len(target_times), interval=50)
2605            return ani
2606    
2607        else:  # mode == 'imshow'
2608            fig, ax = plt.subplots(figsize=(7, 6))
2609            ax.set_xlabel('x')
2610            ax.set_ylabel('y')
2611            ax.set_title('Initial condition')
2612    
2613            # extent uses physical coordinates so axes show real x/y values
2614            extent = [self.x_grid[0], self.x_grid[-1], self.y_grid[0], self.y_grid[-1]]
2615    
2616            if component == 'angle':
2617                vmin, vmax = -np.pi, np.pi
2618                cmap = 'twilight'
2619            else:
2620                vmin, vmax = np.min(data0), np.max(data0)
2621                cmap = 'viridis'
2622    
2623            im = ax.imshow(data0, extent=extent, origin='lower', cmap=cmap,
2624                           vmin=vmin, vmax=vmax, aspect='auto')
2625            cbar = fig.colorbar(im, ax=ax)
2626            cbar.set_label(f"{component} of u")
2627            plt.tight_layout()
2628    
2629            # containers for dynamic overlay artists (stored on function object)
2630            # update_im.contour_art and update_im.scatter_art will be created dynamically
2631    
2632            def _update_im(frame_number):
2633                frame = frame_indices[frame_number]
2634                current_data = get_component(self.frames[frame])
2635    
2636                # update raster
2637                im.set_data(current_data)
2638                if component != 'angle':
2639                    # dynamic per-frame scaling (keeps contrast when amplitude varies)
2640                    cmin = np.nanmin(current_data)
2641                    cmax = np.nanmax(current_data)
2642                    # avoid identical vmin==vmax
2643                    if cmax > cmin:
2644                        im.set_clim(cmin, cmax)
2645    
2646                # remove previous contour if exists
2647                if overlay == 'contour':
2648                    if hasattr(update_im, 'contour_art') and update_im.contour_art is not None:
2649                        for coll in update_im.contour_art.collections:
2650                            try:
2651                                coll.remove()
2652                            except Exception:
2653                                pass
2654                        update_im.contour_art = None
2655                    # draw new contours (use meshgrid coords)
2656                    try:
2657                        update_im.contour_art = ax.contour(self.X, self.Y, current_data, levels=10, cmap='cool')
2658                    except Exception:
2659                        # fallback: contour with axis coordinates (x_grid, y_grid)
2660                        Xc, Yc = np.meshgrid(self.x_grid, self.y_grid)
2661                        update_im.contour_art = ax.contour(Xc, Yc, current_data, levels=10, cmap='cool')
2662    
2663                # remove previous scatter if exists
2664                if overlay == 'front':
2665                    if hasattr(update_im, 'scatter_art') and update_im.scatter_art is not None:
2666                        try:
2667                            update_im.scatter_art.remove()
2668                        except Exception:
2669                            pass
2670                        update_im.scatter_art = None
2671    
2672                    dx = self.x_grid[1] - self.x_grid[0]
2673                    dy = self.y_grid[1] - self.y_grid[0]
2674                    du_dy, du_dx = np.gradient(current_data, dy, dx)
2675                    grad_norm = np.sqrt(du_dx**2 + du_dy**2)
2676                    local_max = (grad_norm == maximum_filter(grad_norm, size=5))
2677                    if np.max(grad_norm) > 0:
2678                        normalized = grad_norm[local_max] / np.max(grad_norm)
2679                    else:
2680                        normalized = np.zeros(np.count_nonzero(local_max))
2681                    colors = cm.plasma(normalized)
2682                    update_im.scatter_art = ax.scatter(self.X[local_max], self.Y[local_max],
2683                                                       c=colors, s=10, alpha=0.8)
2684    
2685                current_time = target_times[frame_number]
2686                ax.set_title(f'Solution at t = {current_time:.2f}')
2687                # return main image plus any overlay artists present so Matplotlib can redraw them
2688                artists = [im]
2689                if overlay == 'contour' and hasattr(update_im, 'contour_art') and update_im.contour_art is not None:
2690                    artists.extend(update_im.contour_art.collections)
2691                if overlay == 'front' and hasattr(update_im, 'scatter_art') and update_im.scatter_art is not None:
2692                    artists.append(update_im.scatter_art)
2693                return tuple(artists)
2694    
2695            ani = FuncAnimation(fig, update_im, frames=len(target_times), interval=50)
2696            return ani
2697
2698    def test(self, u_exact, t_eval=None, norm='relative', threshold=1e-2, component='real'):
2699        """
2700        Test the solver against an exact solution.
2701
2702        This method quantitatively compares the numerical solution with a provided exact solution 
2703        at a specified time using either relative or absolute error norms. It supports both 
2704        stationary and time-dependent problems in 1D and 2D. If enabled, it also generates plots 
2705        of the solution, exact solution, and pointwise error.
2706
2707        Parameters
2708        ----------
2709        u_exact : callable
2710            Exact solution function taking spatial coordinates and optionally time as arguments.
2711        t_eval : float, optional
2712            Time at which to compare solutions. For non-stationary problems, defaults to final time Lt.
2713            Ignored for stationary problems.
2714        norm : str {'relative', 'absolute'}
2715            Type of error norm used in comparison.
2716        threshold : float
2717            Acceptable error threshold; raises an assertion if exceeded.
2718        plot : bool
2719            Whether to display visual comparison plots (default: True).
2720        component : str {'real', 'imag', 'abs'}
2721            Component of the solution to compare and visualize.
2722
2723        Raises
2724        ------
2725        ValueError
2726            If unsupported dimension is encountered or requested evaluation time exceeds simulation duration.
2727        AssertionError
2728            If computed error exceeds the given threshold.
2729
2730        Prints
2731        ------
2732        - Information about the closest available frame to the requested evaluation time.
2733        - Computed error value and comparison to threshold.
2734
2735        Notes
2736        -----
2737        - For time-dependent problems, the solution is extracted from precomputed frames.
2738        - Plots are adapted to spatial dimension: line plots for 1D, image plots for 2D.
2739        - The method ensures consistent handling of real, imaginary, and magnitude components.
2740        """
2741        if self.is_stationary:
2742            print("Testing a stationary solution.")
2743            u_num = self.u
2744    
2745            # Compute exact solution
2746            if self.dim == 1:
2747                u_ex = u_exact(self.X)
2748            elif self.dim == 2:
2749                u_ex = u_exact(self.X, self.Y)
2750            else:
2751                raise ValueError("Unsupported dimension.")
2752            actual_t = None
2753        else:
2754            if t_eval is None:
2755                t_eval = self.Lt
2756    
2757            save_interval = max(1, self.Nt // self.n_frames)
2758            frame_times = np.arange(0, self.Lt + self.dt, save_interval * self.dt)
2759            frame_index = np.argmin(np.abs(frame_times - t_eval))
2760            actual_t = frame_times[frame_index]
2761            print(f"Closest available time to t_eval={t_eval}: {actual_t}")
2762    
2763            if frame_index >= len(self.frames):
2764                raise ValueError(f"Time t = {t_eval} exceeds simulation duration.")
2765    
2766            u_num = self.frames[frame_index]
2767    
2768            # Compute exact solution at the actual time
2769            if self.dim == 1:
2770                u_ex = u_exact(self.X, actual_t)
2771            elif self.dim == 2:
2772                u_ex = u_exact(self.X, self.Y, actual_t)
2773            else:
2774                raise ValueError("Unsupported dimension.")
2775    
2776        # Select component
2777        if component == 'real':
2778            diff = np.real(u_num) - np.real(u_ex)
2779            ref = np.real(u_ex)
2780        elif component == 'imag':
2781            diff = np.imag(u_num) - np.imag(u_ex)
2782            ref = np.imag(u_ex)
2783        elif component == 'abs':
2784            diff = np.abs(u_num) - np.abs(u_ex)
2785            ref = np.abs(u_ex)
2786        else:
2787            raise ValueError("Invalid component.")
2788    
2789        # Compute error
2790        if norm == 'relative':
2791            error = np.linalg.norm(diff) / np.linalg.norm(ref)
2792        elif norm == 'absolute':
2793            error = np.linalg.norm(diff)
2794        else:
2795            raise ValueError("Unknown norm type.")
2796    
2797        label_time = f"t = {actual_t}" if actual_t is not None else ""
2798        print(f"Test error {label_time}: {error:.3e}")
2799        assert error < threshold, f"Error too large {label_time}: {error:.3e}"
2800    
2801        # Plot
2802        if self.plot:
2803            if self.dim == 1:
2804                plt.figure(figsize=(12, 6))
2805                plt.subplot(2, 1, 1)
2806                plt.plot(self.X, np.real(u_num), label='Numerical')
2807                plt.plot(self.X, np.real(u_ex), '--', label='Exact')
2808                plt.title(f'Solution {label_time}, error = {error:.2e}')
2809                plt.legend()
2810                plt.grid()
2811    
2812                plt.subplot(2, 1, 2)
2813                plt.plot(self.X, np.abs(diff), color='red')
2814                plt.title('Absolute Error')
2815                plt.grid()
2816                plt.tight_layout()
2817                plt.show()
2818            else:
2819                extent = [-self.Lx/2, self.Lx/2, -self.Ly/2, self.Ly/2]
2820                plt.figure(figsize=(15, 5))
2821                plt.subplot(1, 3, 1)
2822                plt.title("Numerical Solution")
2823                plt.imshow(np.abs(u_num), origin='lower', extent=extent, cmap='viridis')
2824                plt.colorbar()
2825    
2826                plt.subplot(1, 3, 2)
2827                plt.title("Exact Solution")
2828                plt.imshow(np.abs(u_ex), origin='lower', extent=extent, cmap='viridis')
2829                plt.colorbar()
2830    
2831                plt.subplot(1, 3, 3)
2832                plt.title(f"Error (Norm = {error:.2e})")
2833                plt.imshow(np.abs(diff), origin='lower', extent=extent, cmap='inferno')
2834                plt.colorbar()
2835                plt.tight_layout()
2836                plt.show()
2837
2838        return error

A partial differential equation (PDE) solver based on spectral methods using Fourier transforms.

This solver supports symbolic specification of PDEs via SymPy and numerical solution using high-order spectral techniques. It is designed for both linear and nonlinear time-dependent PDEs, as well as stationary pseudo-differential problems.

Key Features:

  • Symbolic PDE parsing using SymPy expressions
  • 1D and 2D spatial domains with periodic boundary conditions
  • Fourier-based spectral discretization with dealiasing
  • Temporal integration schemes:
    • Default exponential time stepping
    • ETD-RK4 (Exponential Time Differencing Runge-Kutta of 4th order)
  • Nonlinear terms handled through pseudo-spectral evaluation
  • Built-in tools for:
    • Visualization of solutions and error surfaces
    • Symbol analysis of linear and pseudo-differential operators
    • Microlocal analysis (e.g., Hamiltonian flows)
    • CFL condition checking and numerical stability diagnostics

Supported Operators:

  • Linear differential and pseudo-differential operators
  • Nonlinear terms up to second order in derivatives
  • Symbolic operator composition and adjoints
  • Asymptotic inversion of elliptic operators for stationary problems

Example Usage:

>>> from PDESolver import *
>>> u = Function('u')
>>> t, x = symbols('t x')
>>> eq = Eq(diff(u(t, x), t), diff(u(t, x), x, 2) + u(t, x)**2)
>>> def _initial(x): return np.sin(x)
>>> solver = PDESolver(eq)
>>> solver.setup(Lx=2*np.pi, Nx=128, Lt=1.0, Nt=1000, initial_condition=initial)
>>> solver.solve()
>>> ani = solver.animate()
>>> HTML(ani.to_jshtml())  # Display animation in Jupyter notebook
PDESolver(equation, time_scheme='default', dealiasing_ratio=0.6666666666666666)
 70    def __init__(self, equation, time_scheme='default', dealiasing_ratio=2/3):
 71        """
 72        Initialize the PDE solver with a given equation.
 73
 74        This method analyzes the input partial differential equation (PDE), 
 75        identifies the unknown function and its dependencies, determines whether 
 76        the problem is stationary or time-dependent, and prepares symbolic and 
 77        numerical structures for solving in spectral space.
 78
 79        Supported features:
 80        
 81        - 1D and 2D problems
 82        - Time-dependent and stationary equations
 83        - Linear and nonlinear terms
 84        - Pseudo-differential operators via `psiOp`
 85        - Source terms and boundary conditions
 86
 87        The equation is parsed to extract linear, nonlinear, source, and 
 88        pseudo-differential components. Symbolic manipulation is used to derive 
 89        the Fourier representation of linear operators when applicable.
 90
 91        Parameters
 92        ----------
 93        equation : sympy.Eq 
 94            The PDE expressed as a SymPy equation.
 95        time_scheme : str
 96            Temporal integration scheme: 
 97                - 'default' for exponential 
 98                - time-stepping or 'ETD-RK4' for fourth-order exponential 
 99                - time differencing Runge–Kutta.
100        dealiasing_ratio : float
101            Fraction of high-frequency modes to zero out 
102            during dealiasing (e.g., 2/3 for standard truncation).
103
104        Attributes initialized:
105        
106        - self.u: the unknown function (e.g., u(t, x))
107        - self.dim: spatial dimension (1 or 2)
108        - self.spatial_vars: list of spatial variables (e.g., [x] or [x, y])
109        - self.is_stationary: boolean indicating if the problem is stationary
110        - self.linear_terms: dictionary mapping derivative orders to coefficients
111        - self.nonlinear_terms: list of nonlinear expressions
112        - self.source_terms: list of source functions
113        - self.pseudo_terms: list of pseudo-differential operator expressions
114        - self.has_psi: boolean indicating presence of pseudo-differential operators
115        - self.fft / self.ifft: appropriate FFT routines based on spatial dimension
116        - self.kx, self.ky: symbolic wavenumber variables for Fourier space
117
118        Raises:
119            ValueError: If the equation does not contain exactly one unknown function,
120                        if unsupported dimensions are detected, or invalid dependencies.
121        """
122        self.time_scheme = time_scheme # 'default'  or 'ETD-RK4'
123        self.dealiasing_ratio = dealiasing_ratio
124        
125        print("\n*********************************")
126        print("* Partial differential equation *")
127        print("*********************************\n")
128        pprint(equation, num_columns=NUM_COLS)
129        
130        # Extract symbols and function from the equation
131        functions = equation.atoms(Function)
132        
133        # Ignore the wrappers psiOp and Op
134        excluded_wrappers = {'psiOp', 'Op'}
135        
136        # Extract the candidate fonctions (excluding wrappers)
137        candidate_functions = [
138            f for f in functions 
139            if f.func.__name__ not in excluded_wrappers
140        ]
141        
142        # Keep only user functions (u(x), u(x, t), etc.)
143        candidate_functions = [
144            f for f in functions
145            if isinstance(f, AppliedUndef)
146        ]
147        
148        # Stationary detection: no dependence on t
149        self.is_stationary = all(
150            not any(str(arg) == 't' for arg in f.args)
151            for f in candidate_functions
152        )
153        
154        if len(candidate_functions) != 1:
155            print("candidate_functions :", candidate_functions)
156            raise ValueError("The equation must contain exactly one unknown function")
157        
158        self.u = candidate_functions[0]
159
160        self.u_eq = self.u
161
162        args = self.u.args
163        
164        if self.is_stationary:
165            if len(args) not in (1, 2):
166                raise ValueError("Stationary problems must depend on 1 or 2 spatial variables")
167            self.spatial_vars = args
168        else:
169            if len(args) < 2 or len(args) > 3:
170                raise ValueError("The function must depend on t and at least one spatial variable (x [, y])")
171            self.t = args[0]
172            self.spatial_vars = args[1:]
173
174        self.dim = len(self.spatial_vars)
175        if self.dim == 1:
176            self.x = self.spatial_vars[0]
177            self.y = None
178        elif self.dim == 2:
179            self.x, self.y = self.spatial_vars
180        else:
181            raise ValueError("Only 1D and 2D problems are supported.")
182
183        if self.dim == 1:
184            self.fft = partial(fft, workers=FFT_WORKERS)
185            self.ifft = partial(ifft, workers=FFT_WORKERS)
186        else:
187            self.fft = partial(fft2, workers=FFT_WORKERS)
188            self.ifft = partial(ifft2, workers=FFT_WORKERS)
189            
190        # Parse the equation
191        self.linear_terms = {}
192        self.nonlinear_terms = []
193        self.symbol_terms = []
194        self.source_terms = []
195        self.pseudo_terms = []
196        self.temporal_order = 0  # Order of the temporal derivative
197        self.linear_terms, self.nonlinear_terms, self.symbol_terms, self.source_terms, self.pseudo_terms = self._parse_equation(equation)
198        # flag : pseudo‑differential operator present ?
199        self.has_psi = bool(self.pseudo_terms)
200        if self.has_psi:
201            print('⚠️  Pseudo‑differential operator detected: all other linear terms have been rejected.')
202            self.is_spatial = False
203            for coeff, expr in self.pseudo_terms:
204                if expr.has(self.x) or (self.dim == 2 and expr.has(self.y)):
205                    self.is_spatial = True
206                    break
207    
208        if self.dim == 1:
209            self.kx = symbols('kx')
210        elif self.dim == 2:
211            self.kx, self.ky = symbols('kx ky')
212    
213        # Compute linear operator
214        if not self.is_stationary:
215            self._compute_linear_operator()
216        else:
217            self.psi_ops = []
218            for coeff, sym_expr in self.pseudo_terms:
219                psi = PseudoDifferentialOperator(sym_expr, self.spatial_vars, self.u, mode='symbol')
220                self.psi_ops.append((coeff, psi))

Initialize the PDE solver with a given equation.

This method analyzes the input partial differential equation (PDE), identifies the unknown function and its dependencies, determines whether the problem is stationary or time-dependent, and prepares symbolic and numerical structures for solving in spectral space.

Supported features:

  • 1D and 2D problems
  • Time-dependent and stationary equations
  • Linear and nonlinear terms
  • Pseudo-differential operators via psiOp
  • Source terms and boundary conditions

The equation is parsed to extract linear, nonlinear, source, and pseudo-differential components. Symbolic manipulation is used to derive the Fourier representation of linear operators when applicable.

Parameters

equation : sympy.Eq The PDE expressed as a SymPy equation. time_scheme : str Temporal integration scheme: - 'default' for exponential - time-stepping or 'ETD-RK4' for fourth-order exponential - time differencing Runge–Kutta. dealiasing_ratio : float Fraction of high-frequency modes to zero out during dealiasing (e.g., 2/3 for standard truncation).

Attributes initialized:

  • self.u: the unknown function (e.g., u(t, x))
  • self.dim: spatial dimension (1 or 2)
  • self.spatial_vars: list of spatial variables (e.g., or [x, y])
  • self.is_stationary: boolean indicating if the problem is stationary
  • self.linear_terms: dictionary mapping derivative orders to coefficients
  • self.nonlinear_terms: list of nonlinear expressions
  • self.source_terms: list of source functions
  • self.pseudo_terms: list of pseudo-differential operator expressions
  • self.has_psi: boolean indicating presence of pseudo-differential operators
  • self.fft / self.ifft: appropriate FFT routines based on spatial dimension
  • self.kx, self.ky: symbolic wavenumber variables for Fourier space

Raises: ValueError: If the equation does not contain exactly one unknown function, if unsupported dimensions are detected, or invalid dependencies.

time_scheme
dealiasing_ratio
is_stationary
u
u_eq
dim
linear_terms
nonlinear_terms
symbol_terms
source_terms
pseudo_terms
temporal_order
has_psi
def setup( self, Lx, Ly=None, Nx=None, Ny=None, Lt=1.0, Nt=100, boundary_condition='periodic', initial_condition=None, initial_velocity=None, n_frames=100, plot=True):
559    def setup(self, Lx, Ly=None, Nx=None, Ny=None, Lt=1.0, Nt=100, boundary_condition='periodic',
560              initial_condition=None, initial_velocity=None, n_frames=100, plot=True):
561        """
562        Configure the spatial/temporal grid and initialize the solution field.
563    
564        This method sets up the computational domain, initializes spatial and temporal grids,
565        applies boundary conditions, and prepares symbolic and numerical operators.
566        It also performs essential analyses such as:
567        
568            - CFL condition verification (for stability)
569            - Symbol analysis (e.g., dispersion relation, regularity)
570            - Wave propagation analysis for second-order equations
571    
572        If pseudo-differential operators (ψOp) are present, symbolic analysis is skipped
573        in favor of interactive exploration via `interactive_symbol_analysis`.
574    
575        Parameters
576        ----------
577        Lx : float
578            Size of the spatial domain along x-axis.
579        Ly : float, optional
580            Size of the spatial domain along y-axis (for 2D problems).
581        Nx : int
582            Number of spatial points along x-axis.
583        Ny : int, optional
584            Number of spatial points along y-axis (for 2D problems).
585        Lt : float, default=1.0
586            Total simulation time.
587        Nt : int, default=100
588            Number of time steps.
589        initial_condition : callable
590            Function returning the initial state u(x, 0) or u(x, y, 0).
591        initial_velocity : callable, optional
592            Function returning the initial time derivative ∂ₜu(x, 0) or ∂ₜu(x, y, 0),
593            required for second-order equations.
594        n_frames : int, default=100
595            Number of time frames to store during simulation for visualization or output.
596    
597        Raises
598        ------
599        ValueError
600            If mandatory parameters are missing (e.g., Nx not given in 1D, Ly/Ny not given in 2D).
601    
602        Notes
603        -----
604        - The spatial discretization assumes periodic boundary conditions by default.
605        - Fourier transforms are computed using real-to-complex FFTs (`scipy.fft.fft`, `fft2`).
606        - Frequency arrays (`KX`, `KY`) are defined following standard spectral conventions.
607        - Dealiasing is applied using a sharp cutoff filter at a fraction of the maximum frequency.
608        - For second-order equations, initial acceleration is derived from the governing operator.
609        - Symbolic analysis includes plotting of the symbol's real/imaginary/absolute values
610          and dispersion relation.
611    
612        See Also
613        --------
614        setup_1D : Sets up internal variables for one-dimensional problems.
615        setup_2D : Sets up internal variables for two-dimensional problems.
616        initialize_conditions : Applies initial data and enforces compatibility.
617        check_cfl_condition : Verifies time step against stability constraints.
618        plot_symbol : Visualizes the linear operator’s symbol in frequency space.
619        analyze_wave_propagation : Analyzes group velocity.
620        interactive_symbol_analysis : Interactive tools for ψOp-based equations.
621        """
622        
623        # Temporal parameters
624        self.Lt, self.Nt = Lt, Nt
625        self.dt = Lt / Nt
626        self.n_frames = n_frames
627        self.frames = []
628        self.initial_condition = initial_condition
629        self.boundary_condition = boundary_condition
630        self.plot = plot
631
632        if self.boundary_condition == 'dirichlet' and not self.has_psi:
633            raise ValueError(
634                "Dirichlet boundary conditions require the equation to be defined via a pseudo-differential operator (psiOp). "
635                "Please provide an equation involving psiOp for non-periodic boundary treatment."
636            )
637    
638        # Dimension checks
639        if self.dim == 1:
640            if Nx is None:
641                raise ValueError("Nx must be specified in 1D.")
642            self._setup_1D(Lx, Nx)
643        else:
644            if None in (Ly, Ny):
645                raise ValueError("In 2D, Ly and Ny must be provided.")
646            self._setup_2D(Lx, Ly, Nx, Ny)
647    
648        # Initialization of solution and velocities
649        if not self.is_stationary:
650            self._initialize_conditions(initial_condition, initial_velocity)
651            
652        # Symbol analysis if present
653        if self.has_psi:
654            print("⚠️ For psiOp, use interactive_symbol_analysis.")
655        else:
656            if self.L_symbolic == 0:
657                print("⚠️ Linear operator is null.")
658            else:
659                self._check_cfl_condition()
660                self._check_symbol_conditions()
661                if plot:
662                	self._plot_symbol()
663                	if self.temporal_order == 2:
664                		self._analyze_wave_propagation()

Configure the spatial/temporal grid and initialize the solution field.

This method sets up the computational domain, initializes spatial and temporal grids, applies boundary conditions, and prepares symbolic and numerical operators. It also performs essential analyses such as:

- CFL condition verification (for stability)
- Symbol analysis (e.g., dispersion relation, regularity)
- Wave propagation analysis for second-order equations

If pseudo-differential operators (ψOp) are present, symbolic analysis is skipped in favor of interactive exploration via interactive_symbol_analysis.

Parameters

Lx : float Size of the spatial domain along x-axis. Ly : float, optional Size of the spatial domain along y-axis (for 2D problems). Nx : int Number of spatial points along x-axis. Ny : int, optional Number of spatial points along y-axis (for 2D problems). Lt : float, default=1.0 Total simulation time. Nt : int, default=100 Number of time steps. initial_condition : callable Function returning the initial state u(x, 0) or u(x, y, 0). initial_velocity : callable, optional Function returning the initial time derivative ∂ₜu(x, 0) or ∂ₜu(x, y, 0), required for second-order equations. n_frames : int, default=100 Number of time frames to store during simulation for visualization or output.

Raises

ValueError If mandatory parameters are missing (e.g., Nx not given in 1D, Ly/Ny not given in 2D).

Notes

  • The spatial discretization assumes periodic boundary conditions by default.
  • Fourier transforms are computed using real-to-complex FFTs (scipy.fft.fft, fft2).
  • Frequency arrays (KX, KY) are defined following standard spectral conventions.
  • Dealiasing is applied using a sharp cutoff filter at a fraction of the maximum frequency.
  • For second-order equations, initial acceleration is derived from the governing operator.
  • Symbolic analysis includes plotting of the symbol's real/imaginary/absolute values and dispersion relation.

See Also

setup_1D : Sets up internal variables for one-dimensional problems. setup_2D : Sets up internal variables for two-dimensional problems. initialize_conditions : Applies initial data and enforces compatibility. check_cfl_condition : Verifies time step against stability constraints. plot_symbol : Visualizes the linear operator’s symbol in frequency space. analyze_wave_propagation : Analyzes group velocity. interactive_symbol_analysis : Interactive tools for ψOp-based equations.

def solve(self):
1421    def solve(self):
1422        """
1423        Solve the partial differential equation numerically using spectral methods.
1424        
1425        This method evolves the solution in time using a combination of:
1426        - Fourier-based linear evolution (with dealiasing)
1427        - Nonlinear term handling via pseudo-spectral evaluation
1428        - Support for pseudo-differential operators (psiOp)
1429        - Source terms and boundary conditions
1430        
1431        The solver supports:
1432        - 1D and 2D spatial domains
1433        - First and second-order time evolution
1434        - Periodic and Dirichlet boundary conditions
1435        - Time-stepping schemes: default, ETD-RK4
1436        
1437        Returns:
1438            list[np.ndarray]: A list of solution arrays at each saved time frame.
1439        
1440        Side Effects:
1441            - Updates self.frames: stores solution snapshots
1442            - Updates self.energy_history: records total energy if enabled
1443            
1444        Algorithm Overview:
1445            For each time step:
1446                1. Evaluate source contributions (if any)
1447                2. Apply time evolution:
1448                    - Order 1:
1449                        - With psiOp: uses step_order1_with_psi
1450                        - With ETD-RK4: exponential time differencing
1451                        - Default: linear + nonlinear update
1452                    - Order 2:
1453                        - With psiOp: uses step_order2_with_psi
1454                        - With ETD-RK4: second-order exponential scheme
1455                        - Default: second-order leapfrog-style update
1456                3. Enforce boundary conditions
1457                4. Save solution snapshot periodically
1458                5. Record energy (for second-order systems without psiOp)
1459        """
1460        print('\n*******************')
1461        print('* Solving the PDE *')
1462        print('*******************\n')
1463        save_interval = max(1, self.Nt // self.n_frames)
1464        self.energy_history = []
1465        for step in range(self.Nt):
1466            if hasattr(self, 'source_terms') and self.source_terms:
1467                source_contribution = np.zeros_like(self.X, dtype=np.float64)
1468                for term in self.source_terms:
1469                    try:
1470                        if self.dim == 1:
1471                            source_func = lambdify((self.t, self.x), term, 'numpy')
1472                            source_contribution += source_func(step * self.dt, self.X)
1473                        elif self.dim == 2:
1474                            source_func = lambdify((self.t, self.x, self.y), term, 'numpy')
1475                            source_contribution += source_func(step * self.dt, self.X, self.Y)
1476                    except Exception as e:
1477                        print(f'Error evaluating source term {term}: {e}')
1478            else:
1479                source_contribution = 0
1480
1481            if self.temporal_order == 1:
1482                if self.has_psi:
1483                    u_new = self._step_order1_with_psi(source_contribution)
1484                elif hasattr(self, 'time_scheme') and self.time_scheme == 'ETD-RK4':
1485                    u_new = self._step_ETD_RK4(self.u_prev)
1486                else:
1487                    u_hat = self.fft(self.u_prev)
1488                    u_hat *= self.exp_L
1489                    u_hat *= self.dealiasing_mask
1490                    u_lin = self.ifft(u_hat)
1491                    u_nl = self._apply_nonlinear(u_lin)
1492                    u_new = u_lin + u_nl + source_contribution
1493                self._apply_boundary(u_new)
1494                self.u_prev = u_new
1495
1496            elif self.temporal_order == 2:
1497                if self.has_psi:
1498                    u_new = self._step_order2_with_psi(source_contribution)
1499                else:
1500                    if hasattr(self, 'time_scheme') and self.time_scheme == 'ETD-RK4':
1501                        u_new, v_new = self._step_ETD_RK4_order2(self.u_prev, self.v_prev)
1502                    else:
1503                        u_hat = self.fft(self.u_prev)
1504                        v_hat = self.fft(self.v_prev)
1505                        u_new_hat = self.cos_omega_dt * u_hat + self.sin_omega_dt * self.inv_omega * v_hat
1506                        v_new_hat = -self.omega_val * self.sin_omega_dt * u_hat + self.cos_omega_dt * v_hat
1507                        u_new = self.ifft(u_new_hat)
1508                        v_new = self.ifft(v_new_hat)
1509                        u_nl = self._apply_nonlinear(self.u_prev, is_v=False)
1510                        v_nl = self._apply_nonlinear(self.v_prev, is_v=True)
1511                        u_new += (u_nl + source_contribution) * self.dt ** 2 / 2
1512                        v_new += (u_nl + source_contribution) * self.dt
1513                    self._apply_boundary(u_new)
1514                    self._apply_boundary(v_new)
1515                    self.u_prev = u_new
1516                    self.v_prev = v_new
1517
1518            if step % save_interval == 0:
1519                self.frames.append(self.u_prev.copy())
1520
1521            if self.temporal_order == 2 and (not self.has_psi):
1522                E = self._compute_energy()
1523                self.energy_history.append(E)
1524
1525        return self.frames  

Solve the partial differential equation numerically using spectral methods.

This method evolves the solution in time using a combination of:

  • Fourier-based linear evolution (with dealiasing)
  • Nonlinear term handling via pseudo-spectral evaluation
  • Support for pseudo-differential operators (psiOp)
  • Source terms and boundary conditions

The solver supports:

  • 1D and 2D spatial domains
  • First and second-order time evolution
  • Periodic and Dirichlet boundary conditions
  • Time-stepping schemes: default, ETD-RK4

Returns: list[np.ndarray]: A list of solution arrays at each saved time frame.

Side Effects: - Updates self.frames: stores solution snapshots - Updates self.energy_history: records total energy if enabled

Algorithm Overview: For each time step: 1. Evaluate source contributions (if any) 2. Apply time evolution: - Order 1: - With psiOp: uses step_order1_with_psi - With ETD-RK4: exponential time differencing - Default: linear + nonlinear update - Order 2: - With psiOp: uses step_order2_with_psi - With ETD-RK4: second-order exponential scheme - Default: second-order leapfrog-style update 3. Enforce boundary conditions 4. Save solution snapshot periodically 5. Record energy (for second-order systems without psiOp)

def solve_stationary_psiOp(self, order=3):
1527    def solve_stationary_psiOp(self, order=3):
1528        """
1529        Solve stationary pseudo-differential equations of the form P[u] = f(x) or P[u] = f(x,y) using asymptotic inversion.
1530    
1531        This method computes the solution to a stationary (time-independent) pseudo-differential equation
1532        where the operator P is defined via symbolic expressions (psiOp). It constructs an asymptotic right inverse R 
1533        such that P∘R ≈ Id, then applies it to the source term f using either direct Fourier multiplication 
1534        (when the symbol is spatially independent) or Kohn–Nirenberg quantization (when spatial dependence is present).
1535    
1536        The inversion is based on the principal symbol of the operator and its asymptotic expansion up to the given order.
1537        Ellipticity of the symbol is checked numerically before inversion to ensure well-posedness.
1538    
1539        Parameters
1540        ----------
1541        order : int, default=3
1542            Order of the asymptotic expansion used to construct the right inverse of the pseudo-differential operator.
1543        method : str, optional
1544            Inversion strategy:
1545            - 'diagonal' (default): Fast approximate inversion using diagonal operators in frequency space.
1546            - 'full'                : Pointwise exact inversion (slower but more accurate).
1547    
1548        Returns
1549        -------
1550        ndarray
1551            The computed solution u(x) in 1D or u(x, y) in 2D as a NumPy array over the spatial grid.
1552    
1553        Raises
1554        ------
1555        ValueError
1556            If no pseudo-differential operator (psiOp) is defined.
1557            If linear or nonlinear terms other than psiOp are present.
1558            If the symbol is not elliptic on the grid.
1559            If no source term is provided for the right-hand side.
1560    
1561        Notes
1562        -----
1563        - The method assumes the problem is fully stationary: time derivatives must be absent.
1564        - Requires the equation to be purely pseudo-differential (no Op, Derivative, or nonlinear terms).
1565        - Symbol evaluation and inversion are dimension-aware (supports both 1D and 2D problems).
1566        - Supports optimization paths when the symbol does not depend on spatial variables.
1567    
1568        See Also
1569        --------
1570        right_inverse_asymptotic : Constructs the asymptotic inverse of the pseudo-differential operator.
1571        kohn_nirenberg           : Numerical implementation of general pseudo-differential operators.
1572        is_elliptic_numerically  : Verifies numerical ellipticity of the symbol.
1573        """
1574
1575        print("\n*******************************")
1576        print("* Solving the stationnary PDE *")
1577        print("*******************************\n")
1578        print("boundary condition: ",self.boundary_condition)
1579        
1580
1581        if not self.has_psi:
1582            raise ValueError("Only supports problems with psiOp.")
1583    
1584        if self.linear_terms or self.nonlinear_terms:
1585            raise ValueError("Stationary psiOp problems must be linear and purely pseudo-differential.")
1586
1587        if self.boundary_condition not in ('periodic', 'dirichlet'):
1588            raise ValueError(
1589                "For stationary PDEs, boundary conditions must be explicitly defined. "
1590                "Supported types are 'periodic' and 'dirichlet'."
1591            )    
1592            
1593        if self.dim == 1:
1594            x = self.x
1595            xi = symbols('xi', real=True)
1596            spatial_vars = (x,)
1597            freq_vars = (xi,)
1598            X, KX = self.X, self.KX
1599        elif self.dim == 2:
1600            x, y = self.x, self.y
1601            xi, eta = symbols('xi eta', real=True)
1602            spatial_vars = (x, y)
1603            freq_vars = (xi, eta)
1604            X, Y, KX, KY = self.X, self.Y, self.KX, self.KY
1605        else:
1606            raise ValueError("Unsupported spatial dimension.")
1607    
1608        total_symbol = sum(coeff * psi.expr for coeff, psi in self.psi_ops)
1609        psi_total = PseudoDifferentialOperator(total_symbol, spatial_vars, mode='symbol')
1610    
1611        # Check ellipticity
1612        if self.dim == 1:
1613            is_elliptic = psi_total.is_elliptic_numerically(X, KX)
1614        else:
1615            is_elliptic = psi_total.is_elliptic_numerically((X[:, 0], Y[0, :]), (KX[:, 0], KY[0, :]))
1616        if not is_elliptic:
1617            raise ValueError("❌ The pseudo-differential symbol is not numerically elliptic on the grid.")
1618        print("✅ Elliptic pseudo-differential symbol: inversion allowed.")
1619    
1620        R_symbol = psi_total.right_inverse_asymptotic(order=order)
1621        print('Right inverse asymptotic symbol:')
1622        pprint(R_symbol, num_columns=NUM_COLS)
1623        
1624        # ========================================================================
1625        # FIX: Always lambdify with all variables for consistency
1626        # ========================================================================
1627        if self.dim == 1:
1628            # Always include both x and xi in the signature
1629            R_func = lambdify((x, xi), R_symbol, modules='numpy')
1630        elif self.dim == 2:
1631            # Always include all four variables
1632            R_func = lambdify((x, y, xi, eta), R_symbol, modules='numpy')
1633        
1634        # Prepare right-hand side
1635        if self.source_terms:
1636            f_expr = sum(self.source_terms)
1637            used_vars = [v for v in spatial_vars if f_expr.has(v)]
1638            f_func = lambdify(used_vars, -f_expr, modules='numpy')
1639            if self.dim == 1:
1640                rhs = f_func(self.x_grid) if used_vars else np.zeros_like(self.x_grid)
1641            else:
1642                rhs = f_func(self.X, self.Y) if used_vars else np.zeros_like(self.X)
1643        elif self.initial_condition:
1644            raise ValueError('Initial condition should be None for stationnary equation.')
1645        else:
1646            raise ValueError('No source term provided to construct the right-hand side.')
1647        
1648        f_hat = self.fft(rhs)
1649        
1650        # ========================================================================
1651        # Application of the inverse operator
1652        # ========================================================================
1653        if self.boundary_condition == 'periodic':
1654            if self.dim == 1:
1655                # Check if optimization is possible
1656                if not R_symbol.has(x):
1657                    print('⚡ Optimization: symbol independent of x – direct product in Fourier.')
1658                    # Create wrapper that ignores x
1659                    def _R_func_optimized(kx_val):
1660                        return R_func(0.0, kx_val)  # x=0 since it doesn't matter
1661                    
1662                    R_vals = _R_func_optimized(self.KX)
1663                    u_hat = R_vals * f_hat
1664                    u = self.ifft(u_hat)
1665                else:
1666                    print('⚙️ 1D Kohn-Nirenberg Quantification')
1667                    from psiop import kohn_nirenberg_fft
1668                    u = kohn_nirenberg_fft(
1669                        u_vals=rhs,
1670                        symbol_func=R_func,  # Now has correct signature (x, xi)
1671                        x_grid=self.x_grid,
1672                        kx=self.kx,
1673                        fft_func=self.fft,
1674                        ifft_func=self.ifft,
1675                        dim=1
1676                    )
1677                    
1678            elif self.dim == 2:
1679                if not R_symbol.has(x) and not R_symbol.has(y):
1680                    print('⚡ Optimization: Symbol independent of x and y – direct product in 2D Fourier.')
1681                    # Create wrapper that ignores x, y
1682                    def _R_func_optimized(kx_val, ky_val):
1683                        return R_func(0.0, 0.0, kx_val, ky_val)
1684                    
1685                    R_vals = _R_func_optimized(self.KX, self.KY)
1686                    u_hat = R_vals * f_hat
1687                    u = self.ifft(u_hat)
1688                else:
1689                    print('⚙️ 2D Kohn-Nirenberg Quantification')
1690                    from psiop import kohn_nirenberg_fft
1691                    u = kohn_nirenberg_fft(
1692                        u_vals=rhs,
1693                        symbol_func=R_func,  # Now has correct signature (x, y, xi, eta)
1694                        x_grid=self.x_grid,
1695                        kx=self.kx,
1696                        fft_func=self.fft,
1697                        ifft_func=self.ifft,
1698                        dim=2,
1699                        y_grid=self.y_grid,
1700                        ky=self.ky
1701                    )
1702            self.u = u
1703            return u
1704            
1705        elif self.boundary_condition == 'dirichlet':
1706            from psiop import kohn_nirenberg_nonperiodic
1707            
1708            if self.dim == 1:
1709                u = kohn_nirenberg_nonperiodic(
1710                    u_vals=rhs,
1711                    x_grid=self.x_grid,
1712                    xi_grid=self.kx,
1713                    symbol_func=R_func  # Now has correct signature (x, xi)
1714                )
1715            elif self.dim == 2:
1716                u = kohn_nirenberg_nonperiodic(
1717                    u_vals=rhs,
1718                    x_grid=(self.x_grid, self.y_grid),
1719                    xi_grid=(self.kx, self.ky),
1720                    symbol_func=R_func  # Now has correct signature (x, y, xi, eta)
1721                )
1722            self.u = u
1723            return u
1724        
1725        else:
1726            raise ValueError(f"Invalid boundary condition '{self.boundary_condition}'. Supported types are 'periodic' and 'dirichlet'.")

Solve stationary pseudo-differential equations of the form P[u] = f(x) or P[u] = f(x,y) using asymptotic inversion.

This method computes the solution to a stationary (time-independent) pseudo-differential equation where the operator P is defined via symbolic expressions (psiOp). It constructs an asymptotic right inverse R such that P∘R ≈ Id, then applies it to the source term f using either direct Fourier multiplication (when the symbol is spatially independent) or Kohn–Nirenberg quantization (when spatial dependence is present).

The inversion is based on the principal symbol of the operator and its asymptotic expansion up to the given order. Ellipticity of the symbol is checked numerically before inversion to ensure well-posedness.

Parameters

order : int, default=3 Order of the asymptotic expansion used to construct the right inverse of the pseudo-differential operator. method : str, optional Inversion strategy: - 'diagonal' (default): Fast approximate inversion using diagonal operators in frequency space. - 'full' : Pointwise exact inversion (slower but more accurate).

Returns

ndarray The computed solution u(x) in 1D or u(x, y) in 2D as a NumPy array over the spatial grid.

Raises

ValueError If no pseudo-differential operator (psiOp) is defined. If linear or nonlinear terms other than psiOp are present. If the symbol is not elliptic on the grid. If no source term is provided for the right-hand side.

Notes

  • The method assumes the problem is fully stationary: time derivatives must be absent.
  • Requires the equation to be purely pseudo-differential (no Op, Derivative, or nonlinear terms).
  • Symbol evaluation and inversion are dimension-aware (supports both 1D and 2D problems).
  • Supports optimization paths when the symbol does not depend on spatial variables.

See Also

right_inverse_asymptotic : Constructs the asymptotic inverse of the pseudo-differential operator. kohn_nirenberg : Numerical implementation of general pseudo-differential operators. is_elliptic_numerically : Verifies numerical ellipticity of the symbol.

def plot_energy(self, log=False):
2317    def plot_energy(self, log=False):
2318        """
2319        Plot the time evolution of the total energy for wave equations. 
2320        Visualizes the energy computed during simulation for both 1D and 2D cases. 
2321        Requires temporal_order=2 and prior execution of compute_energy() during solve().
2322        
2323        Parameters:
2324            log : bool
2325                If True, displays energy on a logarithmic scale to highlight exponential decay/growth.
2326        
2327        Notes:
2328            - Energy is defined as E(t) = 1/2 ∫ [ (∂ₜu)² + |L¹⸍²u|² ] dx
2329            - Only available if energy monitoring was activated in solve()
2330            - Automatically skips plotting if no energy data is available
2331        
2332        Displays:
2333            - Time vs. Total Energy plot with grid and legend
2334            - Appropriate axis labels and dimensional context (1D/2D)
2335            - Logarithmic or linear scaling based on input parameter
2336        """
2337        if not hasattr(self, 'energy_history') or not self.energy_history:
2338            print("No energy data recorded. Call compute_energy() within solve().")
2339            return
2340    
2341        # Time vector for plotting
2342        t = np.linspace(0, self.Lt, len(self.energy_history))
2343    
2344        # Create the figure
2345        plt.figure(figsize=(6, 4))
2346        if log:
2347            plt.semilogy(t, self.energy_history, label="Energy (log scale)")
2348        else:
2349            plt.plot(t, self.energy_history, label="Energy")
2350    
2351        # Axis labels and title
2352        plt.xlabel("Time")
2353        plt.ylabel("Total energy")
2354        plt.title("Energy evolution ({}D)".format(self.dim))
2355    
2356        # Display options
2357        plt.grid(True)
2358        plt.legend()
2359        plt.tight_layout()
2360        plt.show()

Plot the time evolution of the total energy for wave equations. Visualizes the energy computed during simulation for both 1D and 2D cases. Requires temporal_order=2 and prior execution of compute_energy() during solve().

Parameters: log : bool If True, displays energy on a logarithmic scale to highlight exponential decay/growth.

Notes: - Energy is defined as E(t) = 1/2 ∫ [ (∂ₜu)² + |L¹⸍²u|² ] dx - Only available if energy monitoring was activated in solve() - Automatically skips plotting if no energy data is available

Displays: - Time vs. Total Energy plot with grid and legend - Appropriate axis labels and dimensional context (1D/2D) - Logarithmic or linear scaling based on input parameter

def show_stationary_solution(self, u=None, component='abs', cmap='viridis'):
2362    def show_stationary_solution(self, u=None, component='abs', cmap='viridis'):
2363        """
2364        Display the stationary solution computed by solve_stationary_psiOp.
2365
2366        This method visualizes the solution of a pseudo-differential equation 
2367        solved in stationary mode. It supports both 1D and 2D spatial domains, 
2368        with options to display different components of the solution (real, 
2369        imaginary, absolute value, or phase).
2370
2371        Parameters
2372        ----------
2373        u : ndarray, optional
2374            Precomputed solution array. If None, calls solve_stationary_psiOp() 
2375            to compute the solution.
2376        component : str, optional {'real', 'imag', 'abs', 'angle'}
2377            Component of the complex-valued solution to display:
2378            - 'real': Real part
2379            - 'imag': Imaginary part
2380            - 'abs' : Absolute value (modulus)
2381            - 'angle' : Phase (argument)
2382        cmap : str, optional
2383            Colormap used for 2D visualization (default: 'viridis').
2384
2385        Raises
2386        ------
2387        ValueError
2388            If an invalid component is specified or if the spatial dimension 
2389            is not supported (only 1D and 2D are implemented).
2390
2391        Notes
2392        -----
2393        - In 1D, the solution is displayed using a standard line plot.
2394        - In 2D, the solution is visualized as a 3D surface plot.
2395        """
2396        def _get_component(u):
2397            if component == 'real':
2398                return np.real(u)
2399            elif component == 'imag':
2400                return np.imag(u)
2401            elif component == 'abs':
2402                return np.abs(u)
2403            elif component == 'angle':
2404                return np.angle(u)
2405            else:
2406                raise ValueError("Invalid component")
2407                
2408        if u is None:
2409            u = self.solve_stationary_psiOp()
2410
2411        if self.dim == 1:
2412            # Plot the solution in 1D
2413            plt.figure(figsize=(8, 4))
2414            plt.plot(self.x_grid, get_component(u), label=f'{component} of u')
2415            plt.xlabel('x')
2416            plt.ylabel(f'{component} of u')
2417            plt.title('Stationary solution (1D)')
2418            plt.grid(True)
2419            plt.legend()
2420            plt.tight_layout()
2421            plt.show()
2422    
2423        elif self.dim == 2:
2424            fig = plt.figure(figsize=(12, 6))
2425            ax = fig.add_subplot(111, projection='3d')
2426            ax.set_xlabel('x')
2427            ax.set_ylabel('y')
2428            ax.set_zlabel(f'{component.title()} of u')
2429            plt.title('Stationary solution (2D)')    
2430            data0 = get_component(u)
2431            ax.plot_surface(self.X, self.Y, data0, cmap='viridis')
2432            plt.tight_layout()
2433            plt.show()
2434    
2435        else:
2436            raise ValueError("Only 1D and 2D display are supported.")

Display the stationary solution computed by solve_stationary_psiOp.

This method visualizes the solution of a pseudo-differential equation solved in stationary mode. It supports both 1D and 2D spatial domains, with options to display different components of the solution (real, imaginary, absolute value, or phase).

Parameters

u : ndarray, optional Precomputed solution array. If None, calls solve_stationary_psiOp() to compute the solution. component : str, optional {'real', 'imag', 'abs', 'angle'} Component of the complex-valued solution to display: - 'real': Real part - 'imag': Imaginary part - 'abs' : Absolute value (modulus) - 'angle' : Phase (argument) cmap : str, optional Colormap used for 2D visualization (default: 'viridis').

Raises

ValueError If an invalid component is specified or if the spatial dimension is not supported (only 1D and 2D are implemented).

Notes

  • In 1D, the solution is displayed using a standard line plot.
  • In 2D, the solution is visualized as a 3D surface plot.
def animate(self, component='abs', overlay='contour', mode='surface'):
2438    def animate(self, component='abs', overlay='contour', mode='surface'):
2439        """
2440        Create an animated plot of the solution evolution over time.
2441    
2442        This method generates a dynamic visualization of the stored solution frames
2443        `self.frames`. It supports:
2444          - 1D line animation (unchanged),
2445          - 2D surface animation (original behavior, 'surface'),
2446          - 2D image animation using imshow (new, 'imshow') which is faster and
2447            often clearer for large grids.
2448    
2449        Parameters
2450        ----------
2451        component : str, optional, one of {'real', 'imag', 'abs', 'angle'}
2452            Which component of the complex field to visualize:
2453              - 'real'  : Re(u)
2454              - 'imag'  : Im(u)
2455              - 'abs'   : |u|
2456              - 'angle' : arg(u)
2457            Default is 'abs'.
2458    
2459        overlay : str or None, optional, one of {'contour', 'front', None}
2460            For 2D modes only. If None, no overlay is drawn.
2461              - 'contour' : draw contour lines on top (or beneath for 3D surface)
2462              - 'front'   : detect and mark wavefronts using gradient maxima
2463            Default is 'contour'.
2464    
2465        mode : str, optional, one of {'surface', 'imshow'}
2466            2D rendering mode. 'surface' keeps the original 3D surface plot.
2467            'imshow' draws a 2D raster (faster, often more readable).
2468            Default is 'surface' for backward compatibility.
2469    
2470        Returns
2471        -------
2472        FuncAnimation
2473            A Matplotlib `FuncAnimation` instance (you can display it in a notebook
2474            or save it to file).
2475    
2476        Notes
2477        -----
2478        - The method uses the same time-mapping logic as before (linear sampling of
2479          stored frames to animation frames).
2480        - For 'angle' the color scale is fixed between -π and π.
2481        - For other components, color scaling is by default dynamically adapted per
2482          frame in 'imshow' mode (this avoids extreme clipping if amplitudes vary).
2483        - Overlays are updated cleanly: previous contour/scatter artists are removed
2484          before drawing the next frame to avoid memory/visual accumulation.
2485        - Animation interval is 50 ms per frame (unchanged).
2486        """
2487        def _get_component(u):
2488            if component == 'real':
2489                return np.real(u)
2490            elif component == 'imag':
2491                return np.imag(u)
2492            elif component == 'abs':
2493                return np.abs(u)
2494            elif component == 'angle':
2495                return np.angle(u)
2496            else:
2497                raise ValueError("Invalid component: choose 'real','imag','abs' or 'angle'")
2498    
2499        print("\n*********************")
2500        print("* Solution plotting *")
2501        print("*********************\n")
2502    
2503        # === Calculate time vector of stored frames ===
2504        save_interval = max(1, self.Nt // self.n_frames)
2505        frame_times = np.arange(0, self.Lt + self.dt, save_interval * self.dt)
2506    
2507        # === Target times for animation ===
2508        target_times = np.linspace(0, self.Lt, self.n_frames // 2)
2509    
2510        # Map target times to nearest frame indices
2511        frame_indices = [np.argmin(np.abs(frame_times - t)) for t in target_times]
2512    
2513        # -------------------------
2514        # 1D case (unchanged logic)
2515        # -------------------------
2516        if self.dim == 1:
2517            fig, ax = plt.subplots()
2518            initial = get_component(self.frames[0])
2519            line, = ax.plot(self.X, np.real(initial) if np.iscomplexobj(initial) else initial)
2520            ax.set_ylim(np.min(initial), np.max(initial))
2521            ax.set_xlabel('x')
2522            ax.set_ylabel(f'{component} of u')
2523            ax.set_title('Initial condition')
2524            plt.tight_layout()
2525    
2526            def _update_1d(frame_number):
2527                frame = frame_indices[frame_number]
2528                ydata = get_component(self.frames[frame])
2529                ydata_real = np.real(ydata) if np.iscomplexobj(ydata) else ydata
2530                line.set_ydata(ydata_real)
2531                ax.set_ylim(np.min(ydata_real), np.max(ydata_real))
2532                current_time = target_times[frame_number]
2533                ax.set_title(f't = {current_time:.2f}')
2534                return (line,)
2535    
2536            ani = FuncAnimation(fig, update_1d, frames=len(target_times), interval=50)
2537            return ani
2538    
2539        # -------------------------
2540        # 2D case
2541        # -------------------------
2542        # Validate mode
2543        if mode not in ('surface', 'imshow'):
2544            raise ValueError("Invalid mode: choose 'surface' or 'imshow'")
2545    
2546        # Common data
2547        data0 = get_component(self.frames[0])
2548    
2549        if mode == 'surface':
2550            # original surface behavior, but ensure clean updates
2551            fig = plt.figure(figsize=(14, 8))
2552            ax = fig.add_subplot(111, projection='3d')
2553            ax.set_xlabel('x')
2554            ax.set_ylabel('y')
2555            ax.set_zlabel(f'{component.title()} of u')
2556            ax.zaxis.labelpad = 0
2557            ax.set_title('Initial condition')
2558    
2559            surf = ax.plot_surface(self.X, self.Y, data0, cmap='viridis')
2560            plt.tight_layout()
2561    
2562            def _update_surface(frame_number):
2563                frame = frame_indices[frame_number]
2564                current_data = get_component(self.frames[frame])
2565                z_offset = np.max(current_data) + 0.05 * (np.max(current_data) - np.min(current_data))
2566    
2567                ax.clear()
2568                surf_obj = ax.plot_surface(self.X, self.Y, current_data,
2569                                           cmap='viridis',
2570                                           vmin=(-np.pi if component == 'angle' else None),
2571                                           vmax=(np.pi if component == 'angle' else None))
2572                # overlays
2573                if overlay == 'contour':
2574                    # place contours slightly below the surface (use offset)
2575                    try:
2576                        ax.contour(self.X, self.Y, current_data, levels=10, cmap='cool', offset=z_offset)
2577                    except Exception:
2578                        # fallback: simple contour without offset if not supported
2579                        ax.contour(self.X, self.Y, current_data, levels=10, cmap='cool')
2580    
2581                elif overlay == 'front':
2582                    dx = self.x_grid[1] - self.x_grid[0]
2583                    dy = self.y_grid[1] - self.y_grid[0]
2584                    # numpy.gradient: axis0 -> y spacing, axis1 -> x spacing
2585                    du_dy, du_dx = np.gradient(current_data, dy, dx)
2586                    grad_norm = np.sqrt(du_dx**2 + du_dy**2)
2587                    local_max = (grad_norm == maximum_filter(grad_norm, size=5))
2588                    if np.max(grad_norm) > 0:
2589                        normalized = grad_norm[local_max] / np.max(grad_norm)
2590                    else:
2591                        normalized = np.zeros(np.count_nonzero(local_max))
2592                    colors = cm.plasma(normalized)
2593                    ax.scatter(self.X[local_max], self.Y[local_max],
2594                               z_offset * np.ones_like(self.X[local_max]),
2595                               color=colors, s=10, alpha=0.8)
2596    
2597                ax.set_xlabel('x')
2598                ax.set_ylabel('y')
2599                ax.set_zlabel(f'{component.title()} of u')
2600                current_time = target_times[frame_number]
2601                ax.set_title(f'Solution at t = {current_time:.2f}')
2602                return (surf_obj,)
2603    
2604            ani = FuncAnimation(fig, update_surface, frames=len(target_times), interval=50)
2605            return ani
2606    
2607        else:  # mode == 'imshow'
2608            fig, ax = plt.subplots(figsize=(7, 6))
2609            ax.set_xlabel('x')
2610            ax.set_ylabel('y')
2611            ax.set_title('Initial condition')
2612    
2613            # extent uses physical coordinates so axes show real x/y values
2614            extent = [self.x_grid[0], self.x_grid[-1], self.y_grid[0], self.y_grid[-1]]
2615    
2616            if component == 'angle':
2617                vmin, vmax = -np.pi, np.pi
2618                cmap = 'twilight'
2619            else:
2620                vmin, vmax = np.min(data0), np.max(data0)
2621                cmap = 'viridis'
2622    
2623            im = ax.imshow(data0, extent=extent, origin='lower', cmap=cmap,
2624                           vmin=vmin, vmax=vmax, aspect='auto')
2625            cbar = fig.colorbar(im, ax=ax)
2626            cbar.set_label(f"{component} of u")
2627            plt.tight_layout()
2628    
2629            # containers for dynamic overlay artists (stored on function object)
2630            # update_im.contour_art and update_im.scatter_art will be created dynamically
2631    
2632            def _update_im(frame_number):
2633                frame = frame_indices[frame_number]
2634                current_data = get_component(self.frames[frame])
2635    
2636                # update raster
2637                im.set_data(current_data)
2638                if component != 'angle':
2639                    # dynamic per-frame scaling (keeps contrast when amplitude varies)
2640                    cmin = np.nanmin(current_data)
2641                    cmax = np.nanmax(current_data)
2642                    # avoid identical vmin==vmax
2643                    if cmax > cmin:
2644                        im.set_clim(cmin, cmax)
2645    
2646                # remove previous contour if exists
2647                if overlay == 'contour':
2648                    if hasattr(update_im, 'contour_art') and update_im.contour_art is not None:
2649                        for coll in update_im.contour_art.collections:
2650                            try:
2651                                coll.remove()
2652                            except Exception:
2653                                pass
2654                        update_im.contour_art = None
2655                    # draw new contours (use meshgrid coords)
2656                    try:
2657                        update_im.contour_art = ax.contour(self.X, self.Y, current_data, levels=10, cmap='cool')
2658                    except Exception:
2659                        # fallback: contour with axis coordinates (x_grid, y_grid)
2660                        Xc, Yc = np.meshgrid(self.x_grid, self.y_grid)
2661                        update_im.contour_art = ax.contour(Xc, Yc, current_data, levels=10, cmap='cool')
2662    
2663                # remove previous scatter if exists
2664                if overlay == 'front':
2665                    if hasattr(update_im, 'scatter_art') and update_im.scatter_art is not None:
2666                        try:
2667                            update_im.scatter_art.remove()
2668                        except Exception:
2669                            pass
2670                        update_im.scatter_art = None
2671    
2672                    dx = self.x_grid[1] - self.x_grid[0]
2673                    dy = self.y_grid[1] - self.y_grid[0]
2674                    du_dy, du_dx = np.gradient(current_data, dy, dx)
2675                    grad_norm = np.sqrt(du_dx**2 + du_dy**2)
2676                    local_max = (grad_norm == maximum_filter(grad_norm, size=5))
2677                    if np.max(grad_norm) > 0:
2678                        normalized = grad_norm[local_max] / np.max(grad_norm)
2679                    else:
2680                        normalized = np.zeros(np.count_nonzero(local_max))
2681                    colors = cm.plasma(normalized)
2682                    update_im.scatter_art = ax.scatter(self.X[local_max], self.Y[local_max],
2683                                                       c=colors, s=10, alpha=0.8)
2684    
2685                current_time = target_times[frame_number]
2686                ax.set_title(f'Solution at t = {current_time:.2f}')
2687                # return main image plus any overlay artists present so Matplotlib can redraw them
2688                artists = [im]
2689                if overlay == 'contour' and hasattr(update_im, 'contour_art') and update_im.contour_art is not None:
2690                    artists.extend(update_im.contour_art.collections)
2691                if overlay == 'front' and hasattr(update_im, 'scatter_art') and update_im.scatter_art is not None:
2692                    artists.append(update_im.scatter_art)
2693                return tuple(artists)
2694    
2695            ani = FuncAnimation(fig, update_im, frames=len(target_times), interval=50)
2696            return ani

Create an animated plot of the solution evolution over time.

This method generates a dynamic visualization of the stored solution frames self.frames. It supports:

  • 1D line animation (unchanged),
  • 2D surface animation (original behavior, 'surface'),
  • 2D image animation using imshow (new, 'imshow') which is faster and often clearer for large grids.

Parameters

component : str, optional, one of {'real', 'imag', 'abs', 'angle'} Which component of the complex field to visualize: - 'real' : Re(u) - 'imag' : Im(u) - 'abs' : |u| - 'angle' : arg(u) Default is 'abs'.

overlay : str or None, optional, one of {'contour', 'front', None} For 2D modes only. If None, no overlay is drawn. - 'contour' : draw contour lines on top (or beneath for 3D surface) - 'front' : detect and mark wavefronts using gradient maxima Default is 'contour'.

mode : str, optional, one of {'surface', 'imshow'} 2D rendering mode. 'surface' keeps the original 3D surface plot. 'imshow' draws a 2D raster (faster, often more readable). Default is 'surface' for backward compatibility.

Returns

FuncAnimation A Matplotlib FuncAnimation instance (you can display it in a notebook or save it to file).

Notes

  • The method uses the same time-mapping logic as before (linear sampling of stored frames to animation frames).
  • For 'angle' the color scale is fixed between -π and π.
  • For other components, color scaling is by default dynamically adapted per frame in 'imshow' mode (this avoids extreme clipping if amplitudes vary).
  • Overlays are updated cleanly: previous contour/scatter artists are removed before drawing the next frame to avoid memory/visual accumulation.
  • Animation interval is 50 ms per frame (unchanged).
def test( self, u_exact, t_eval=None, norm='relative', threshold=0.01, component='real'):
2698    def test(self, u_exact, t_eval=None, norm='relative', threshold=1e-2, component='real'):
2699        """
2700        Test the solver against an exact solution.
2701
2702        This method quantitatively compares the numerical solution with a provided exact solution 
2703        at a specified time using either relative or absolute error norms. It supports both 
2704        stationary and time-dependent problems in 1D and 2D. If enabled, it also generates plots 
2705        of the solution, exact solution, and pointwise error.
2706
2707        Parameters
2708        ----------
2709        u_exact : callable
2710            Exact solution function taking spatial coordinates and optionally time as arguments.
2711        t_eval : float, optional
2712            Time at which to compare solutions. For non-stationary problems, defaults to final time Lt.
2713            Ignored for stationary problems.
2714        norm : str {'relative', 'absolute'}
2715            Type of error norm used in comparison.
2716        threshold : float
2717            Acceptable error threshold; raises an assertion if exceeded.
2718        plot : bool
2719            Whether to display visual comparison plots (default: True).
2720        component : str {'real', 'imag', 'abs'}
2721            Component of the solution to compare and visualize.
2722
2723        Raises
2724        ------
2725        ValueError
2726            If unsupported dimension is encountered or requested evaluation time exceeds simulation duration.
2727        AssertionError
2728            If computed error exceeds the given threshold.
2729
2730        Prints
2731        ------
2732        - Information about the closest available frame to the requested evaluation time.
2733        - Computed error value and comparison to threshold.
2734
2735        Notes
2736        -----
2737        - For time-dependent problems, the solution is extracted from precomputed frames.
2738        - Plots are adapted to spatial dimension: line plots for 1D, image plots for 2D.
2739        - The method ensures consistent handling of real, imaginary, and magnitude components.
2740        """
2741        if self.is_stationary:
2742            print("Testing a stationary solution.")
2743            u_num = self.u
2744    
2745            # Compute exact solution
2746            if self.dim == 1:
2747                u_ex = u_exact(self.X)
2748            elif self.dim == 2:
2749                u_ex = u_exact(self.X, self.Y)
2750            else:
2751                raise ValueError("Unsupported dimension.")
2752            actual_t = None
2753        else:
2754            if t_eval is None:
2755                t_eval = self.Lt
2756    
2757            save_interval = max(1, self.Nt // self.n_frames)
2758            frame_times = np.arange(0, self.Lt + self.dt, save_interval * self.dt)
2759            frame_index = np.argmin(np.abs(frame_times - t_eval))
2760            actual_t = frame_times[frame_index]
2761            print(f"Closest available time to t_eval={t_eval}: {actual_t}")
2762    
2763            if frame_index >= len(self.frames):
2764                raise ValueError(f"Time t = {t_eval} exceeds simulation duration.")
2765    
2766            u_num = self.frames[frame_index]
2767    
2768            # Compute exact solution at the actual time
2769            if self.dim == 1:
2770                u_ex = u_exact(self.X, actual_t)
2771            elif self.dim == 2:
2772                u_ex = u_exact(self.X, self.Y, actual_t)
2773            else:
2774                raise ValueError("Unsupported dimension.")
2775    
2776        # Select component
2777        if component == 'real':
2778            diff = np.real(u_num) - np.real(u_ex)
2779            ref = np.real(u_ex)
2780        elif component == 'imag':
2781            diff = np.imag(u_num) - np.imag(u_ex)
2782            ref = np.imag(u_ex)
2783        elif component == 'abs':
2784            diff = np.abs(u_num) - np.abs(u_ex)
2785            ref = np.abs(u_ex)
2786        else:
2787            raise ValueError("Invalid component.")
2788    
2789        # Compute error
2790        if norm == 'relative':
2791            error = np.linalg.norm(diff) / np.linalg.norm(ref)
2792        elif norm == 'absolute':
2793            error = np.linalg.norm(diff)
2794        else:
2795            raise ValueError("Unknown norm type.")
2796    
2797        label_time = f"t = {actual_t}" if actual_t is not None else ""
2798        print(f"Test error {label_time}: {error:.3e}")
2799        assert error < threshold, f"Error too large {label_time}: {error:.3e}"
2800    
2801        # Plot
2802        if self.plot:
2803            if self.dim == 1:
2804                plt.figure(figsize=(12, 6))
2805                plt.subplot(2, 1, 1)
2806                plt.plot(self.X, np.real(u_num), label='Numerical')
2807                plt.plot(self.X, np.real(u_ex), '--', label='Exact')
2808                plt.title(f'Solution {label_time}, error = {error:.2e}')
2809                plt.legend()
2810                plt.grid()
2811    
2812                plt.subplot(2, 1, 2)
2813                plt.plot(self.X, np.abs(diff), color='red')
2814                plt.title('Absolute Error')
2815                plt.grid()
2816                plt.tight_layout()
2817                plt.show()
2818            else:
2819                extent = [-self.Lx/2, self.Lx/2, -self.Ly/2, self.Ly/2]
2820                plt.figure(figsize=(15, 5))
2821                plt.subplot(1, 3, 1)
2822                plt.title("Numerical Solution")
2823                plt.imshow(np.abs(u_num), origin='lower', extent=extent, cmap='viridis')
2824                plt.colorbar()
2825    
2826                plt.subplot(1, 3, 2)
2827                plt.title("Exact Solution")
2828                plt.imshow(np.abs(u_ex), origin='lower', extent=extent, cmap='viridis')
2829                plt.colorbar()
2830    
2831                plt.subplot(1, 3, 3)
2832                plt.title(f"Error (Norm = {error:.2e})")
2833                plt.imshow(np.abs(diff), origin='lower', extent=extent, cmap='inferno')
2834                plt.colorbar()
2835                plt.tight_layout()
2836                plt.show()
2837
2838        return error

Test the solver against an exact solution.

This method quantitatively compares the numerical solution with a provided exact solution at a specified time using either relative or absolute error norms. It supports both stationary and time-dependent problems in 1D and 2D. If enabled, it also generates plots of the solution, exact solution, and pointwise error.

Parameters

u_exact : callable Exact solution function taking spatial coordinates and optionally time as arguments. t_eval : float, optional Time at which to compare solutions. For non-stationary problems, defaults to final time Lt. Ignored for stationary problems. norm : str {'relative', 'absolute'} Type of error norm used in comparison. threshold : float Acceptable error threshold; raises an assertion if exceeded. plot : bool Whether to display visual comparison plots (default: True). component : str {'real', 'imag', 'abs'} Component of the solution to compare and visualize.

Raises

ValueError If unsupported dimension is encountered or requested evaluation time exceeds simulation duration. AssertionError If computed error exceeds the given threshold.

Prints

  • Information about the closest available frame to the requested evaluation time.
  • Computed error value and comparison to threshold.

Notes

  • For time-dependent problems, the solution is extracted from precomputed frames.
  • Plots are adapted to spatial dimension: line plots for 1D, image plots for 2D.
  • The method ensures consistent handling of real, imaginary, and magnitude components.
class LagrangianHamiltonianConverter:
 37class LagrangianHamiltonianConverter:
 38    """
 39    Bidirectional converter between Lagrangian and Hamiltonian (Legendre transform),
 40    with optional Legendre–Fenchel (convex conjugate) support and robust numeric fallback.
 41
 42    Main API:
 43      L_to_H(L_expr, coords, u, p_vars, return_symbol_only=False, force=False,
 44             method="legendre", fenchel_opts=None)
 45
 46        - method: "legendre" (default), "fenchel_symbolic", "fenchel_numeric"
 47        - If method == "fenchel_numeric" returns (H_repr, xi_vars, numeric_callable)
 48          otherwise returns (H_expr, xi_vars)
 49    """
 50
 51    _numeric_cache = {}
 52
 53    # --------------------
 54    # Utilities
 55    # --------------------
 56    @staticmethod
 57    def _is_quadratic_in_p(L_expr, p_vars):
 58        """
 59        Robust test: returns True only if L_expr is polynomial of degree ≤ 2 in each p_var.
 60        Falls back to False for non-polynomial expressions (Abs, sqrt, etc.).
 61        """
 62        for p in p_vars:
 63            # Quick test: is L polynomial in p?
 64            if not L_expr.is_polynomial(p):
 65                return False
 66            try:
 67                deg = sp.degree(L_expr, p)
 68            except Exception:
 69                return False
 70            if deg is None or deg > 2:
 71                return False
 72        return True
 73
 74    @staticmethod
 75    def _quadratic_legendre(L_expr, p_vars, xi_vars):
 76        """
 77        Analytic Legendre transform for quadratic L: L = 1/2 p^T A p + b^T p + c
 78        Returns (H_expr, sol_map) and raises ValueError if Hessian singular.
 79        """
 80        A = Matrix([[sp.diff(sp.diff(L_expr, p_i), p_j) for p_j in p_vars] for p_i in p_vars])
 81        grad = Matrix([sp.diff(L_expr, p) for p in p_vars])
 82        try:
 83            A_inv = A.inv()
 84        except Exception:
 85            raise ValueError("Quadratic analytic path: Hessian A is singular (non-invertible).")
 86        subs_zero = {p: 0 for p in p_vars}
 87        b_vec = grad.subs(subs_zero)
 88        xi_vec = Matrix(xi_vars)
 89        p_solution_vec = A_inv * (xi_vec - b_vec)
 90        sol = {p_vars[i]: sp.simplify(p_solution_vec[i]) for i in range(len(p_vars))}
 91        H_expr = sum(xi_vars[i] * sol[p_vars[i]] for i in range(len(p_vars))) - sp.simplify(L_expr.subs(sol))
 92        return sp.simplify(H_expr), sol
 93
 94    # ----------------------------
 95    # Numeric Legendre-Fenchel helpers
 96    # ----------------------------
 97    @staticmethod
 98    def _legendre_fenchel_1d_numeric_callable(L_func, p_bounds=(-10.0, 10.0), n_grid=2001, mode="auto",
 99                                             scipy_multistart=5):
100        """
101        Return a callable H_numeric(xi) = sup_p (xi*p - L(p)) for 1D L_func(p).
102        - L_func: callable p -> L(p)
103        - mode: "auto" | "scipy" | "grid"
104        """
105        pmin, pmax = float(p_bounds[0]), float(p_bounds[1])
106
107        def _compute_by_grid(xi):
108            grid = _np.linspace(pmin, pmax, int(n_grid))
109            Lvals = _np.array([float(L_func(p)) for p in grid], dtype=float)
110            S = xi * grid - Lvals
111            idx = int(_np.argmax(S))
112            return float(S[idx]), float(grid[idx])
113
114        def _compute_by_scipy(xi):
115            if not _HAS_SCIPY:
116                return _compute_by_grid(xi)
117
118            def negS(p):
119                p0 = float(p[0])
120                return -(xi * p0 - float(L_func(p0)))
121
122            best_val = -_math.inf
123            best_p = None
124            inits = _np.linspace(pmin, pmax, max(3, int(scipy_multistart)))
125            for x0 in inits:
126                try:
127                    res = _optimize.minimize(negS, x0=[float(x0)], bounds=[(pmin, pmax)], method="L-BFGS-B")
128                    if res.success:
129                        pstar = float(res.x[0])
130                        sval = float(xi * pstar - float(L_func(pstar)))
131                        if sval > best_val:
132                            best_val = sval
133                            best_p = pstar
134                except Exception:
135                    continue
136            if best_p is None:
137                return _compute_by_grid(xi)
138            return best_val, best_p
139
140        compute = _compute_by_scipy if (_HAS_SCIPY and mode != "grid") else _compute_by_grid
141
142        def H_numeric(xi_in):
143            xi_arr = _np.atleast_1d(xi_in).astype(float)
144            out = _np.empty_like(xi_arr, dtype=float)
145            for i, xi in enumerate(xi_arr):
146                val, _ = compute(float(xi))
147                out[i] = val
148            if _np.isscalar(xi_in):
149                return float(out[0])
150            return out
151
152        return H_numeric
153
154    @staticmethod
155    def _legendre_fenchel_nd_numeric_callable(L_func, dim, p_bounds, n_grid_per_dim=41, mode="auto",
156                                              scipy_multistart=10, multistart_restarts=8):
157        """
158        Return callable H_numeric(xi_vector) approximating sup_p (xi·p - L(p)) for dim>=2.
159        - L_func: callable p_vector -> L(p)
160        - p_bounds: tuple/list of per-dimension bounds
161        """
162        pmin_list, pmax_list = p_bounds
163        pmin = [float(v) for v in pmin_list]
164        pmax = [float(v) for v in pmax_list]
165
166        def compute_by_grid(xi_vec):
167            import itertools
168            grids = [_np.linspace(pmin[d], pmax[d], int(n_grid_per_dim)) for d in range(dim)]
169            best = -_math.inf
170            best_p = None
171            for pt in itertools.product(*grids):
172                pt_arr = _np.array(pt, dtype=float)
173                sval = float(_np.dot(xi_vec, pt_arr) - L_func(pt_arr))
174                if sval > best:
175                    best = sval
176                    best_p = pt_arr
177            return best, best_p
178
179        def compute_by_scipy(xi_vec):
180            if not _HAS_SCIPY:
181                return compute_by_grid(xi_vec)
182
183            def negS(p):
184                p = _np.asarray(p, dtype=float)
185                return - (float(_np.dot(xi_vec, p)) - float(L_func(p)))
186
187            best_val = -_math.inf
188            best_p = None
189            center = _np.array([(pmin[d] + pmax[d]) / 2.0 for d in range(dim)], dtype=float)
190            rng = _np.random.default_rng(123456)
191            inits = [center]
192            for k in range(multistart_restarts):
193                r = rng.random(dim)
194                start = _np.array([pmin[d] + r[d] * (pmax[d] - pmin[d]) for d in range(dim)], dtype=float)
195                inits.append(start)
196            for x0 in inits:
197                try:
198                    res = _optimize.minimize(negS, x0=x0, bounds=tuple((pmin[d], pmax[d]) for d in range(dim)),
199                                             method="L-BFGS-B")
200                    if res.success:
201                        pstar = _np.asarray(res.x, dtype=float)
202                        sval = float(_np.dot(xi_vec, pstar) - L_func(pstar))
203                        if sval > best_val:
204                            best_val = sval
205                            best_p = pstar
206                except Exception:
207                    continue
208            if best_p is None:
209                return compute_by_grid(xi_vec)
210            return best_val, best_p
211
212        compute = compute_by_scipy if (_HAS_SCIPY and mode != "grid") else compute_by_grid
213
214        def H_numeric(xi_in):
215            xi_arr = _np.atleast_2d(xi_in).astype(float)
216            if xi_arr.shape[-1] != dim:
217                xi_arr = xi_arr.reshape(-1, dim)
218            out = _np.empty((xi_arr.shape[0],), dtype=float)
219            for i, xivec in enumerate(xi_arr):
220                val, _ = compute(xivec)
221                out[i] = val
222            if out.shape[0] == 1:
223                return float(out[0])
224            return out
225
226        return H_numeric
227
228    # ----------------------------
229    # Main methods
230    # ----------------------------
231    @staticmethod
232    def L_to_H(L_expr, coords, u, p_vars, return_symbol_only=False, force=False,
233               method="legendre", fenchel_opts=None):
234        """
235        Convert L(x,u,p) -> H(x,u,xi) with options for generalized Legendre (Fenchel).
236
237        Parameters:
238          - method: "legendre" (default), "fenchel_symbolic", "fenchel_numeric"
239          - fenchel_opts: dict with options for numeric fenchel
240        """
241        dim = len(coords)
242        if dim == 1:
243            xi_vars = (sp.Symbol('xi', real=True),)
244        elif dim == 2:
245            xi_vars = (sp.Symbol('xi', real=True), sp.Symbol('eta', real=True))
246        else:
247            raise ValueError("Only 1D and 2D dimensions are supported.")
248
249        # Quadratic fast-path (symbolic)
250        if method in ("legendre", "fenchel_symbolic") and LagrangianHamiltonianConverter._is_quadratic_in_p(L_expr, p_vars):
251            try:
252                H_expr, sol = LagrangianHamiltonianConverter._quadratic_legendre(L_expr, p_vars, xi_vars)
253                if return_symbol_only:
254                    H_expr = H_expr.subs(u, 0)
255                return H_expr, xi_vars
256            except Exception:
257                if not force and method == "legendre":
258                    raise
259
260        # CLASSICAL LEGENDRE
261        if method == "legendre":
262            H_p = None
263            try:
264                H_p = sp.hessian(L_expr, p_vars)
265                det_H = sp.simplify(H_p.det())
266            except Exception:
267                det_H = None
268
269            if det_H is not None and det_H == 0 and not force:
270                raise ValueError("Legendre transform not invertible: Hessian singular. Use force=True or Fenchel method.")
271            if det_H is None and not force:
272                raise ValueError("Unable to verify Hessian determinant symbolically. Use force=True to attempt solve().")
273
274            eqs = [sp.Eq(sp.diff(L_expr, p_vars[i]), xi_vars[i]) for i in range(dim)]
275            sol_list = sp.solve(eqs, p_vars, dict=True)
276            if not sol_list:
277                if not force:
278                    raise ValueError("Unable to solve symbolic Legendre relations. Use force=True or Fenchel fallback.")
279            if sol_list:
280                sol = sol_list[0]
281                if isinstance(sol, tuple) and len(sol) == len(p_vars):
282                    sol = {p_vars[i]: sol[i] for i in range(len(p_vars))}
283                H_expr = sum(xi_vars[i]*sol[p_vars[i]] for i in range(dim)) - L_expr.subs(sol)
284                H_expr = sp.simplify(H_expr)
285                if return_symbol_only:
286                    H_expr = H_expr.subs(u, 0)
287                return H_expr, xi_vars
288            raise ValueError("Legendre inversion failed even with solve().")
289
290        # FENCHEL: symbolic attempt
291        # -----------------------------------------------------
292        #  Prevent symbolic Fenchel when L is non-differentiable
293        # -----------------------------------------------------
294        if method == "fenchel_symbolic":
295            if L_expr.has(sp.Abs) or L_expr.has(sp.sign) or any(
296                sp.diff(L_expr, p).has(sp.sign, sp.Abs) for p in p_vars
297            ):
298                raise ValueError(
299                    "Symbolic Fenchel not possible for nonsmooth L (Abs, sign). "
300                    "Use method='fenchel_numeric' instead."
301                )
302
303        if method == "fenchel_symbolic":
304            eqs = [sp.Eq(sp.diff(L_expr, p_vars[i]), xi_vars[i]) for i in range(dim)]
305            sol_list = sp.solve(eqs, p_vars, dict=True)
306            if sol_list:
307                candidates = []
308                for sol in sol_list:
309                    if isinstance(sol, tuple) and len(sol) == len(p_vars):
310                        sol = {p_vars[i]: sol[i] for i in range(len(p_vars))}
311                    S_expr = sum(xi_vars[i] * sol[p_vars[i]] for i in range(dim)) - L_expr.subs(sol)
312                    candidates.append(sp.simplify(S_expr))
313                H_candidates = sp.simplify(sp.Max(*candidates)) if len(candidates) > 1 else candidates[0]
314                if return_symbol_only:
315                    H_candidates = H_candidates.subs(u, 0)
316                return H_candidates, xi_vars
317            raise ValueError("Symbolic Fenchel conjugate not found; use method='fenchel_numeric' for numeric computation.")
318
319        # FENCHEL: numeric path
320        if method == "fenchel_numeric":
321            if fenchel_opts is None:
322                fenchel_opts = {}
323            if dim == 1:
324                p_bounds = fenchel_opts.get("p_bounds", (-10.0, 10.0))
325                n_grid = int(fenchel_opts.get("n_grid", 2001))
326                mode = fenchel_opts.get("mode", "auto")
327                scipy_multistart = int(fenchel_opts.get("scipy_multistart", 8))
328
329                # Build numeric L_func (try lambdify)
330                try:
331                    f_lamb = sp.lambdify((p_vars[0],), L_expr, "numpy")
332                    def L_func_scalar(p):
333                        return float(f_lamb(p))
334                except Exception:
335                    try:
336                        f_lamb = sp.lambdify(p_vars[0], L_expr, "numpy")
337                        def L_func_scalar(p):
338                            return float(f_lamb(p))
339                    except Exception:
340                        def L_func_scalar(p):
341                            return float(sp.N(L_expr.subs({p_vars[0]: p})))
342
343                H_numeric = LagrangianHamiltonianConverter._legendre_fenchel_1d_numeric_callable(
344                    L_func_scalar, p_bounds=p_bounds, n_grid=n_grid, mode=mode,
345                    scipy_multistart=scipy_multistart
346                )
347                H_func = sp.Function("H_numeric")
348                H_repr = H_func(xi_vars[0])
349                LagrangianHamiltonianConverter._numeric_cache[id(H_repr)] = H_numeric
350                return H_repr, xi_vars, H_numeric
351
352            else:
353                # dim == 2
354                p_bounds = fenchel_opts.get("p_bounds", [(-10.0, 10.0), (-10.0, 10.0)])
355                n_grid_per_dim = int(fenchel_opts.get("n_grid_per_dim", 41))
356                mode = fenchel_opts.get("mode", "auto")
357                scipy_multistart = int(fenchel_opts.get("scipy_multistart", 20))
358                multistart_restarts = int(fenchel_opts.get("multistart_restarts", 8))
359
360                f_lamb = None
361                try:
362                    f_lamb = sp.lambdify((p_vars[0], p_vars[1]), L_expr, "numpy")
363                    def L_func_nd(p):
364                        return float(f_lamb(float(p[0]), float(p[1])))
365                except Exception:
366                    try:
367                        f_lamb = sp.lambdify((p_vars,), L_expr, "numpy")
368                        def L_func_nd(p):
369                            return float(f_lamb(tuple(float(v) for v in p)))
370                    except Exception:
371                        def L_func_nd(p):
372                            subs_map = {p_vars[i]: float(p[i]) for i in range(2)}
373                            return float(sp.N(L_expr.subs(subs_map)))
374
375                H_numeric = LagrangianHamiltonianConverter._legendre_fenchel_nd_numeric_callable(
376                    L_func_nd, dim=2, p_bounds=(p_bounds[0], p_bounds[1]),
377                    n_grid_per_dim=n_grid_per_dim, mode=mode,
378                    scipy_multistart=scipy_multistart, multistart_restarts=multistart_restarts
379                )
380                H_func = sp.Function("H_numeric")
381                H_repr = H_func(*xi_vars)
382                LagrangianHamiltonianConverter._numeric_cache[id(H_repr)] = H_numeric
383                return H_repr, xi_vars, H_numeric
384
385        raise ValueError("Unknown method '{}'. Choose 'legendre', 'fenchel_symbolic' or 'fenchel_numeric'.".format(method))
386
387    @staticmethod
388    def H_to_L(H_expr, coords, u, xi_vars, force=False):
389        """
390        Inverse Legendre (classical). Does not attempt Fenchel inverse.
391        """
392        dim = len(coords)
393        if dim == 1:
394            p_vars = (sp.Symbol('p', real=True),)
395        elif dim == 2:
396            p_vars = (sp.Symbol('p_x', real=True), sp.Symbol('p_y', real=True))
397        else:
398            raise ValueError("Only 1D and 2D are supported.")
399
400        eqs = [sp.Eq(sp.diff(H_expr, xi_vars[i]), p_vars[i]) for i in range(dim)]
401        sol = sp.solve(eqs, xi_vars, dict=True)
402        if not sol:
403            if not force:
404                raise ValueError("Unable to symbolically solve p = ∂H/∂ξ for ξ. Use force=True.")
405            sol = sp.solve(eqs, xi_vars)
406        if not sol:
407            raise ValueError("Inverse Legendre transform failed; cannot find ξ(p).")
408        sol = sol[0] if isinstance(sol, list) else sol
409        if isinstance(sol, tuple) and len(sol) == len(xi_vars):
410            sol = {xi_vars[i]: sol[i] for i in range(len(xi_vars))}
411        if not isinstance(sol, dict):
412            if isinstance(sol, list) and sol and isinstance(sol[0], dict):
413                sol = sol[0]
414            else:
415                raise ValueError("Unexpected output from solve(); cannot construct ξ(p).")
416        L_expr = sum(sol[xi_vars[i]] * p_vars[i] for i in range(dim)) - H_expr.subs(sol)
417        return sp.simplify(L_expr), p_vars

Bidirectional converter between Lagrangian and Hamiltonian (Legendre transform), with optional Legendre–Fenchel (convex conjugate) support and robust numeric fallback.

Main API: L_to_H(L_expr, coords, u, p_vars, return_symbol_only=False, force=False, method="legendre", fenchel_opts=None)

- method: "legendre" (default), "fenchel_symbolic", "fenchel_numeric"
- If method == "fenchel_numeric" returns (H_repr, xi_vars, numeric_callable)
  otherwise returns (H_expr, xi_vars)
@staticmethod
def L_to_H( L_expr, coords, u, p_vars, return_symbol_only=False, force=False, method='legendre', fenchel_opts=None):
231    @staticmethod
232    def L_to_H(L_expr, coords, u, p_vars, return_symbol_only=False, force=False,
233               method="legendre", fenchel_opts=None):
234        """
235        Convert L(x,u,p) -> H(x,u,xi) with options for generalized Legendre (Fenchel).
236
237        Parameters:
238          - method: "legendre" (default), "fenchel_symbolic", "fenchel_numeric"
239          - fenchel_opts: dict with options for numeric fenchel
240        """
241        dim = len(coords)
242        if dim == 1:
243            xi_vars = (sp.Symbol('xi', real=True),)
244        elif dim == 2:
245            xi_vars = (sp.Symbol('xi', real=True), sp.Symbol('eta', real=True))
246        else:
247            raise ValueError("Only 1D and 2D dimensions are supported.")
248
249        # Quadratic fast-path (symbolic)
250        if method in ("legendre", "fenchel_symbolic") and LagrangianHamiltonianConverter._is_quadratic_in_p(L_expr, p_vars):
251            try:
252                H_expr, sol = LagrangianHamiltonianConverter._quadratic_legendre(L_expr, p_vars, xi_vars)
253                if return_symbol_only:
254                    H_expr = H_expr.subs(u, 0)
255                return H_expr, xi_vars
256            except Exception:
257                if not force and method == "legendre":
258                    raise
259
260        # CLASSICAL LEGENDRE
261        if method == "legendre":
262            H_p = None
263            try:
264                H_p = sp.hessian(L_expr, p_vars)
265                det_H = sp.simplify(H_p.det())
266            except Exception:
267                det_H = None
268
269            if det_H is not None and det_H == 0 and not force:
270                raise ValueError("Legendre transform not invertible: Hessian singular. Use force=True or Fenchel method.")
271            if det_H is None and not force:
272                raise ValueError("Unable to verify Hessian determinant symbolically. Use force=True to attempt solve().")
273
274            eqs = [sp.Eq(sp.diff(L_expr, p_vars[i]), xi_vars[i]) for i in range(dim)]
275            sol_list = sp.solve(eqs, p_vars, dict=True)
276            if not sol_list:
277                if not force:
278                    raise ValueError("Unable to solve symbolic Legendre relations. Use force=True or Fenchel fallback.")
279            if sol_list:
280                sol = sol_list[0]
281                if isinstance(sol, tuple) and len(sol) == len(p_vars):
282                    sol = {p_vars[i]: sol[i] for i in range(len(p_vars))}
283                H_expr = sum(xi_vars[i]*sol[p_vars[i]] for i in range(dim)) - L_expr.subs(sol)
284                H_expr = sp.simplify(H_expr)
285                if return_symbol_only:
286                    H_expr = H_expr.subs(u, 0)
287                return H_expr, xi_vars
288            raise ValueError("Legendre inversion failed even with solve().")
289
290        # FENCHEL: symbolic attempt
291        # -----------------------------------------------------
292        #  Prevent symbolic Fenchel when L is non-differentiable
293        # -----------------------------------------------------
294        if method == "fenchel_symbolic":
295            if L_expr.has(sp.Abs) or L_expr.has(sp.sign) or any(
296                sp.diff(L_expr, p).has(sp.sign, sp.Abs) for p in p_vars
297            ):
298                raise ValueError(
299                    "Symbolic Fenchel not possible for nonsmooth L (Abs, sign). "
300                    "Use method='fenchel_numeric' instead."
301                )
302
303        if method == "fenchel_symbolic":
304            eqs = [sp.Eq(sp.diff(L_expr, p_vars[i]), xi_vars[i]) for i in range(dim)]
305            sol_list = sp.solve(eqs, p_vars, dict=True)
306            if sol_list:
307                candidates = []
308                for sol in sol_list:
309                    if isinstance(sol, tuple) and len(sol) == len(p_vars):
310                        sol = {p_vars[i]: sol[i] for i in range(len(p_vars))}
311                    S_expr = sum(xi_vars[i] * sol[p_vars[i]] for i in range(dim)) - L_expr.subs(sol)
312                    candidates.append(sp.simplify(S_expr))
313                H_candidates = sp.simplify(sp.Max(*candidates)) if len(candidates) > 1 else candidates[0]
314                if return_symbol_only:
315                    H_candidates = H_candidates.subs(u, 0)
316                return H_candidates, xi_vars
317            raise ValueError("Symbolic Fenchel conjugate not found; use method='fenchel_numeric' for numeric computation.")
318
319        # FENCHEL: numeric path
320        if method == "fenchel_numeric":
321            if fenchel_opts is None:
322                fenchel_opts = {}
323            if dim == 1:
324                p_bounds = fenchel_opts.get("p_bounds", (-10.0, 10.0))
325                n_grid = int(fenchel_opts.get("n_grid", 2001))
326                mode = fenchel_opts.get("mode", "auto")
327                scipy_multistart = int(fenchel_opts.get("scipy_multistart", 8))
328
329                # Build numeric L_func (try lambdify)
330                try:
331                    f_lamb = sp.lambdify((p_vars[0],), L_expr, "numpy")
332                    def L_func_scalar(p):
333                        return float(f_lamb(p))
334                except Exception:
335                    try:
336                        f_lamb = sp.lambdify(p_vars[0], L_expr, "numpy")
337                        def L_func_scalar(p):
338                            return float(f_lamb(p))
339                    except Exception:
340                        def L_func_scalar(p):
341                            return float(sp.N(L_expr.subs({p_vars[0]: p})))
342
343                H_numeric = LagrangianHamiltonianConverter._legendre_fenchel_1d_numeric_callable(
344                    L_func_scalar, p_bounds=p_bounds, n_grid=n_grid, mode=mode,
345                    scipy_multistart=scipy_multistart
346                )
347                H_func = sp.Function("H_numeric")
348                H_repr = H_func(xi_vars[0])
349                LagrangianHamiltonianConverter._numeric_cache[id(H_repr)] = H_numeric
350                return H_repr, xi_vars, H_numeric
351
352            else:
353                # dim == 2
354                p_bounds = fenchel_opts.get("p_bounds", [(-10.0, 10.0), (-10.0, 10.0)])
355                n_grid_per_dim = int(fenchel_opts.get("n_grid_per_dim", 41))
356                mode = fenchel_opts.get("mode", "auto")
357                scipy_multistart = int(fenchel_opts.get("scipy_multistart", 20))
358                multistart_restarts = int(fenchel_opts.get("multistart_restarts", 8))
359
360                f_lamb = None
361                try:
362                    f_lamb = sp.lambdify((p_vars[0], p_vars[1]), L_expr, "numpy")
363                    def L_func_nd(p):
364                        return float(f_lamb(float(p[0]), float(p[1])))
365                except Exception:
366                    try:
367                        f_lamb = sp.lambdify((p_vars,), L_expr, "numpy")
368                        def L_func_nd(p):
369                            return float(f_lamb(tuple(float(v) for v in p)))
370                    except Exception:
371                        def L_func_nd(p):
372                            subs_map = {p_vars[i]: float(p[i]) for i in range(2)}
373                            return float(sp.N(L_expr.subs(subs_map)))
374
375                H_numeric = LagrangianHamiltonianConverter._legendre_fenchel_nd_numeric_callable(
376                    L_func_nd, dim=2, p_bounds=(p_bounds[0], p_bounds[1]),
377                    n_grid_per_dim=n_grid_per_dim, mode=mode,
378                    scipy_multistart=scipy_multistart, multistart_restarts=multistart_restarts
379                )
380                H_func = sp.Function("H_numeric")
381                H_repr = H_func(*xi_vars)
382                LagrangianHamiltonianConverter._numeric_cache[id(H_repr)] = H_numeric
383                return H_repr, xi_vars, H_numeric
384
385        raise ValueError("Unknown method '{}'. Choose 'legendre', 'fenchel_symbolic' or 'fenchel_numeric'.".format(method))

Convert L(x,u,p) -> H(x,u,xi) with options for generalized Legendre (Fenchel).

Parameters:

  • method: "legendre" (default), "fenchel_symbolic", "fenchel_numeric"
  • fenchel_opts: dict with options for numeric fenchel
@staticmethod
def H_to_L(H_expr, coords, u, xi_vars, force=False):
387    @staticmethod
388    def H_to_L(H_expr, coords, u, xi_vars, force=False):
389        """
390        Inverse Legendre (classical). Does not attempt Fenchel inverse.
391        """
392        dim = len(coords)
393        if dim == 1:
394            p_vars = (sp.Symbol('p', real=True),)
395        elif dim == 2:
396            p_vars = (sp.Symbol('p_x', real=True), sp.Symbol('p_y', real=True))
397        else:
398            raise ValueError("Only 1D and 2D are supported.")
399
400        eqs = [sp.Eq(sp.diff(H_expr, xi_vars[i]), p_vars[i]) for i in range(dim)]
401        sol = sp.solve(eqs, xi_vars, dict=True)
402        if not sol:
403            if not force:
404                raise ValueError("Unable to symbolically solve p = ∂H/∂ξ for ξ. Use force=True.")
405            sol = sp.solve(eqs, xi_vars)
406        if not sol:
407            raise ValueError("Inverse Legendre transform failed; cannot find ξ(p).")
408        sol = sol[0] if isinstance(sol, list) else sol
409        if isinstance(sol, tuple) and len(sol) == len(xi_vars):
410            sol = {xi_vars[i]: sol[i] for i in range(len(xi_vars))}
411        if not isinstance(sol, dict):
412            if isinstance(sol, list) and sol and isinstance(sol[0], dict):
413                sol = sol[0]
414            else:
415                raise ValueError("Unexpected output from solve(); cannot construct ξ(p).")
416        L_expr = sum(sol[xi_vars[i]] * p_vars[i] for i in range(dim)) - H_expr.subs(sol)
417        return sp.simplify(L_expr), p_vars

Inverse Legendre (classical). Does not attempt Fenchel inverse.

class HamiltonianSymbolicConverter:
423class HamiltonianSymbolicConverter:
424    """
425    Symbolic converter between Hamiltonians and formal PDEs (psiOp).
426    """
427
428    @staticmethod
429    def decompose_hamiltonian(H_expr, xi_vars):
430        """
431        Decomposes the Hamiltonian into polynomial (local) and non-polynomial (nonlocal) parts.
432        The heuristic treats terms containing sqrt, Abs, or sign as nonlocal.
433        """
434        xi = xi_vars if isinstance(xi_vars, (tuple, list)) else (xi_vars,)
435        poly_terms, nonlocal_terms = 0, 0
436        H_expand = sp.expand(H_expr)
437        for term in H_expand.as_ordered_terms():
438            # Heuristic: treat terms containing sqrt/Abs/sign as nonlocal explicitly
439            # Check if the *current* 'term' (from the outer loop) has these functions.
440            # The original code had a scoping bug in the 'any' statement.
441            if any(func in term.free_symbols for func in [sp.sqrt, sp.Abs, sp.sign]) or \
442               term.has(sp.sqrt) or term.has(sp.Abs) or term.has(sp.sign):
443                # Alternative and more robust check:
444                # This checks if the specific 'term' object contains the specified functions.
445                nonlocal_terms += term
446            elif all(term.is_polynomial(xi_i) for xi_i in xi):
447                poly_terms += term
448            else:
449                nonlocal_terms += term
450        return sp.simplify(poly_terms), sp.simplify(nonlocal_terms)
451
452    @classmethod
453    def hamiltonian_to_symbolic_pde(cls, H_expr, coords, t, u, mode="schrodinger"):
454        dim = len(coords)
455        if dim == 1:
456            xi_vars = (sp.Symbol("xi", real=True),)
457        elif dim == 2:
458            xi_vars = (sp.Symbol("xi", real=True), sp.Symbol("eta", real=True))
459        else:
460            raise ValueError("Only 1D and 2D Hamiltonians are supported.")
461
462        H_poly, H_nonlocal = cls.decompose_hamiltonian(H_expr, xi_vars)
463        H_total = H_poly + H_nonlocal
464        psiOp_H_u = sp.Function("psiOp")(H_total, u)
465
466        if mode == "stationary":
467            E = sp.Symbol("E", real=True)
468            pde = sp.Eq(psiOp_H_u, E * u)
469            formal = "ψOp(H, u) = E u"
470        elif mode == "schrodinger":
471            pde = sp.Eq(sp.I * sp.Derivative(u, t), psiOp_H_u)
472            formal = "i ∂_t u = ψOp(H, u)"
473        elif mode == "wave":
474            pde = sp.Eq(sp.Derivative(u, (t, 2)), -psiOp_H_u)
475            formal = "∂_{tt} u + ψOp(H, u) = 0"
476        else:
477            raise ValueError("mode must be one of: 'stationary', 'schrodinger', 'wave'.")
478
479        coord_str = ", ".join(str(c) for c in coords)
480        xi_str = ", ".join(str(x) for x in xi_vars)
481        formal += f"   (H = H({coord_str}; {xi_str}))"
482
483        return {
484            "pde": sp.simplify(pde),
485            "H_poly": H_poly,
486            "H_nonlocal": H_nonlocal,
487            "formal_string": formal,
488            "mode": mode
489        }

Symbolic converter between Hamiltonians and formal PDEs (psiOp).

@staticmethod
def decompose_hamiltonian(H_expr, xi_vars):
428    @staticmethod
429    def decompose_hamiltonian(H_expr, xi_vars):
430        """
431        Decomposes the Hamiltonian into polynomial (local) and non-polynomial (nonlocal) parts.
432        The heuristic treats terms containing sqrt, Abs, or sign as nonlocal.
433        """
434        xi = xi_vars if isinstance(xi_vars, (tuple, list)) else (xi_vars,)
435        poly_terms, nonlocal_terms = 0, 0
436        H_expand = sp.expand(H_expr)
437        for term in H_expand.as_ordered_terms():
438            # Heuristic: treat terms containing sqrt/Abs/sign as nonlocal explicitly
439            # Check if the *current* 'term' (from the outer loop) has these functions.
440            # The original code had a scoping bug in the 'any' statement.
441            if any(func in term.free_symbols for func in [sp.sqrt, sp.Abs, sp.sign]) or \
442               term.has(sp.sqrt) or term.has(sp.Abs) or term.has(sp.sign):
443                # Alternative and more robust check:
444                # This checks if the specific 'term' object contains the specified functions.
445                nonlocal_terms += term
446            elif all(term.is_polynomial(xi_i) for xi_i in xi):
447                poly_terms += term
448            else:
449                nonlocal_terms += term
450        return sp.simplify(poly_terms), sp.simplify(nonlocal_terms)

Decomposes the Hamiltonian into polynomial (local) and non-polynomial (nonlocal) parts. The heuristic treats terms containing sqrt, Abs, or sign as nonlocal.

@classmethod
def hamiltonian_to_symbolic_pde(cls, H_expr, coords, t, u, mode='schrodinger'):
452    @classmethod
453    def hamiltonian_to_symbolic_pde(cls, H_expr, coords, t, u, mode="schrodinger"):
454        dim = len(coords)
455        if dim == 1:
456            xi_vars = (sp.Symbol("xi", real=True),)
457        elif dim == 2:
458            xi_vars = (sp.Symbol("xi", real=True), sp.Symbol("eta", real=True))
459        else:
460            raise ValueError("Only 1D and 2D Hamiltonians are supported.")
461
462        H_poly, H_nonlocal = cls.decompose_hamiltonian(H_expr, xi_vars)
463        H_total = H_poly + H_nonlocal
464        psiOp_H_u = sp.Function("psiOp")(H_total, u)
465
466        if mode == "stationary":
467            E = sp.Symbol("E", real=True)
468            pde = sp.Eq(psiOp_H_u, E * u)
469            formal = "ψOp(H, u) = E u"
470        elif mode == "schrodinger":
471            pde = sp.Eq(sp.I * sp.Derivative(u, t), psiOp_H_u)
472            formal = "i ∂_t u = ψOp(H, u)"
473        elif mode == "wave":
474            pde = sp.Eq(sp.Derivative(u, (t, 2)), -psiOp_H_u)
475            formal = "∂_{tt} u + ψOp(H, u) = 0"
476        else:
477            raise ValueError("mode must be one of: 'stationary', 'schrodinger', 'wave'.")
478
479        coord_str = ", ".join(str(c) for c in coords)
480        xi_str = ", ".join(str(x) for x in xi_vars)
481        formal += f"   (H = H({coord_str}; {xi_str}))"
482
483        return {
484            "pde": sp.simplify(pde),
485            "H_poly": H_poly,
486            "H_nonlocal": H_nonlocal,
487            "formal_string": formal,
488            "mode": mode
489        }
class SymbolGeometry:
101class SymbolGeometry:
102    """
103    Analyzes the geometric structure of a symbol H(x, ξ)
104    
105    This class computes:
106    - Hamiltonian flow (geodesics)
107    - Jacobian (focusing)
108    - Caustics (singularities)
109    - Periodic orbits
110    - Semiclassical spectrum
111    """
112    
113    def __init__(self, symbol: sp.Expr, x_sym: sp.Symbol, xi_sym: sp.Symbol):
114        """
115        Initialize with a symbolic Hamiltonian
116        
117        Parameters
118        ----------
119        symbol : sympy expression
120            The Hamiltonian H(x, ξ)
121        x_sym, xi_sym : sympy symbols
122            Position and momentum variables
123        """
124        self.H = symbol
125        self.x_sym = x_sym
126        self.xi_sym = xi_sym
127        
128        # Compute derivatives symbolically (DRY principle)
129        self._compute_derivatives()
130        
131        # Convert to numerical functions (cached)
132        self._lambdify_functions()
133    
134    def _compute_derivatives(self):
135        """Compute all necessary derivatives (DRY)"""
136        dH_x = sp.diff(self.H, self.x_sym)
137        self.dH_dx = _sanitize(dH_x)
138        dH_xi = sp.diff(self.H, self.xi_sym)
139        self.dH_dxi = _sanitize(dH_xi)
140        d2H_x2 = sp.diff(self.dH_dx, self.x_sym)
141        self.d2H_dx2 = _sanitize(d2H_x2)        
142        d2H_xi2 = sp.diff(self.dH_dxi, self.xi_sym)
143        self.d2H_dxi2 = _sanitize(d2H_xi2)        
144        d2H_xxi = sp.diff(self.dH_dx, self.xi_sym)
145        self.d2H_dxdxi = _sanitize(d2H_xxi)
146    
147    def _lambdify_functions(self):
148        """Convert symbolic expressions to numerical functions (DRY)"""
149        vars_tuple = (self.x_sym, self.xi_sym)
150        
151        self.f_H = sp.lambdify(vars_tuple, self.H, 'numpy')
152        self.f_dH_dx = sp.lambdify(vars_tuple, self.dH_dx, 'numpy')
153        self.f_dH_dxi = sp.lambdify(vars_tuple, self.dH_dxi, 'numpy')
154        self.f_d2H_dx2 = sp.lambdify(vars_tuple, self.d2H_dx2, 'numpy')
155        self.f_d2H_dxi2 = sp.lambdify(vars_tuple, self.d2H_dxi2, 'numpy')
156        self.f_d2H_dxdxi = sp.lambdify(vars_tuple, self.d2H_dxdxi, 'numpy')
157    
158    def compute_geodesic(self, x0: float, xi0: float, t_max: float, 
159                        n_points: int = 500) -> Geodesic:
160        """
161        Compute geodesic with Jacobian (for caustics detection)
162        
163        Solves the augmented system:
164        dx/dt = ∂H/∂ξ
165        dξ/dt = -∂H/∂x
166        dJ/dt = ∂²H/∂ξ² J + ∂²H/∂x∂ξ K  (variational equation)
167        dK/dt = -∂²H/∂x∂ξ J - ∂²H/∂x² K
168        
169        Parameters
170        ----------
171        x0, xi0 : float
172            Initial conditions
173        t_max : float
174            Final time
175        n_points : int
176            Number of points
177            
178        Returns
179        -------
180        Geodesic
181            Complete geodesic information
182        """
183        def system(t, z):
184            x, xi, J, K = z
185            try:
186                # Hamilton equations
187                dx = float(self.f_dH_dxi(x, xi))
188                dxi = float(-self.f_dH_dx(x, xi))
189                
190                # Variational equations (Jacobian evolution)
191                d2H_dxi2 = float(self.f_d2H_dxi2(x, xi))
192                d2H_dxdxi = float(self.f_d2H_dxdxi(x, xi))
193                d2H_dx2 = float(self.f_d2H_dx2(x, xi))
194                
195                dJ = d2H_dxi2 * J + d2H_dxdxi * K
196                dK = -d2H_dxdxi * J - d2H_dx2 * K
197                
198                return [dx, dxi, dJ, dK]
199            except:
200                return [0, 0, 0, 0]
201        
202        # Initial conditions: J(0)=0, K(0)=1 (standard initial condition)
203        z0 = [x0, xi0, 0.0, 1.0]
204        
205        sol = solve_ivp(
206            system, [0, t_max], z0,
207            t_eval=np.linspace(0, t_max, n_points),
208            method='DOP853',
209            rtol=1e-10, atol=1e-12
210        )
211        
212        # Compute energy along trajectory
213        H_traj = np.array([self.f_H(sol.y[0][i], sol.y[1][i]) 
214                          for i in range(len(sol.t))])
215        
216        return Geodesic(
217            t=sol.t,
218            x=sol.y[0],
219            xi=sol.y[1],
220            H=H_traj,
221            J=sol.y[2],
222            K=sol.y[3]
223        )
224    
225    def find_periodic_orbits(self, energy: float, 
226                            x_range: Tuple[float, float],
227                            xi_range: Tuple[float, float],
228                            n_attempts: int = 50,
229                            tol_period: float = 1e-3) -> List[PeriodicOrbit]:
230        """
231        Find periodic orbits at fixed energy
232        
233        Strategy: Sample energy surface H(x,ξ)=E and look for closed orbits
234        
235        Parameters
236        ----------
237        energy : float
238            Target energy level
239        x_range, xi_range : tuple
240            Search domain
241        n_attempts : int
242            Number of initial conditions to try
243        tol_period : float
244            Tolerance for periodicity detection
245            
246        Returns
247        -------
248        list of PeriodicOrbit
249            Found periodic orbits
250        """
251        orbits = []
252        x_samples = np.linspace(x_range[0], x_range[1], int(np.sqrt(n_attempts)))
253        
254        for x0_test in x_samples:
255            # Solve H(x0, ξ0) = E for ξ0
256            def energy_eq(xi0):
257                try:
258                    return self.f_H(x0_test, xi0) - energy
259                except:
260                    return 1e10
261            
262            xi_guesses = np.linspace(xi_range[0], xi_range[1], 5)
263            
264            for xi_guess in xi_guesses:
265                try:
266                    result = fsolve(energy_eq, xi_guess, full_output=True)
267                    
268                    if result[2] != 1:  # Check convergence
269                        continue
270                    
271                    xi0 = result[0][0]
272                    
273                    # Verify we're on energy surface
274                    if abs(self.f_H(x0_test, xi0) - energy) > 1e-6:
275                        continue
276                    
277                    # Integrate to detect periodicity
278                    T_max = 20
279                    geo = self.compute_geodesic(x0_test, xi0, T_max, 2000)
280                    
281                    # Find returns to initial point
282                    distances = np.sqrt((geo.x - x0_test)**2 + (geo.xi - xi0)**2)
283                    
284                    # Find local minima (except t=0)
285                    minima_idx = []
286                    for i in range(10, len(distances)-10):
287                        if (distances[i] < distances[i-1] and 
288                            distances[i] < distances[i+1] and
289                            distances[i] < tol_period):
290                            minima_idx.append(i)
291                    
292                    if minima_idx:
293                        idx_period = minima_idx[0]
294                        period = geo.t[idx_period]
295                        
296                        if period > 0.1 and distances[idx_period] < tol_period:
297                            # Compute action S = ∮ ξ dx
298                            x_cycle = geo.x[:idx_period+1]
299                            xi_cycle = geo.xi[:idx_period+1]
300                            t_cycle = geo.t[:idx_period+1]
301                            
302                            dx_dt = np.gradient(x_cycle, t_cycle)
303                            action = np.trapz(xi_cycle * dx_dt, t_cycle)
304                            
305                            # Compute stability (Lyapunov exponent)
306                            stability = self._compute_stability(x0_test, xi0, period)
307                            
308                            orbits.append(PeriodicOrbit(
309                                x0=x0_test,
310                                xi0=xi0,
311                                period=period,
312                                action=action,
313                                energy=energy,
314                                stability=stability,
315                                x_cycle=x_cycle,
316                                xi_cycle=xi_cycle,
317                                t_cycle=t_cycle
318                            ))
319                
320                except:
321                    continue
322        
323        # Remove duplicates
324        return self._remove_duplicate_orbits(orbits)
325    
326    def _compute_stability(self, x0: float, xi0: float, T: float) -> float:
327        """Compute Lyapunov exponent (orbit stability)"""
328        def linearized_system(t, z):
329            x, xi, dx, dxi = z
330            try:
331                vx = float(self.f_dH_dxi(x, xi))
332                vxi = float(-self.f_dH_dx(x, xi))
333                
334                # Linearization
335                A12 = float(self.f_d2H_dxi2(x, xi))
336                A21 = float(-self.f_d2H_dxdxi(x, xi))
337                
338                ddx = A12 * dxi
339                ddxi = A21 * dx
340                
341                return [vx, vxi, ddx, ddxi]
342            except:
343                return [0, 0, 0, 0]
344        
345        epsilon = 1e-6
346        z0 = [x0, xi0, epsilon, 0]
347        
348        sol = solve_ivp(linearized_system, [0, T], z0, method='DOP853', rtol=1e-10)
349        
350        if sol.success and len(sol.y[2]) > 0:
351            perturbation_final = np.sqrt(sol.y[2][-1]**2 + sol.y[3][-1]**2)
352            return np.log(perturbation_final / epsilon) / T
353        else:
354            return np.nan
355    
356    def _remove_duplicate_orbits(self, orbits: List[PeriodicOrbit]) -> List[PeriodicOrbit]:
357        """Remove duplicate periodic orbits"""
358        unique = []
359        for orb in orbits:
360            is_duplicate = False
361            for orb_unique in unique:
362                if (abs(orb.period - orb_unique.period) < 0.1 and
363                    abs(orb.action - orb_unique.action) < 0.1):
364                    is_duplicate = True
365                    break
366            if not is_duplicate:
367                unique.append(orb)
368        return unique
369    
370    def gutzwiller_trace_formula(self, periodic_orbits: List[PeriodicOrbit],
371                                 t_values: np.ndarray, hbar: float = 1.0) -> np.ndarray:
372        """
373        Gutzwiller trace formula (semiclassical)
374        
375        Tr[exp(-iHt/ℏ)] ≈ Σ_γ A_γ exp(iS_γ/ℏ - iπμ_γ/2)
376        
377        Parameters
378        ----------
379        periodic_orbits : list
380            List of periodic orbits
381        t_values : array
382            Time values
383        hbar : float
384            Reduced Planck constant
385            
386        Returns
387        -------
388        array
389            Trace as function of time
390        """
391        trace = np.zeros(len(t_values), dtype=complex)
392        
393        for orb in periodic_orbits:
394            T = orb.period
395            S = orb.action
396            lambda_stab = orb.stability
397            
398            # ✅ CORRECTION 1 : Plus de répétitions (jusqu'à 10)
399            for k in range(1, 11):  # 1 → 11 (au lieu de 5)
400                T_k = k * T
401                S_k = k * S
402                
403                # Stability factor
404                if not np.isnan(lambda_stab) and abs(lambda_stab) > 1e-6:
405                    det_factor = abs(2 * np.sinh(k * lambda_stab * T))
406                else:
407                    det_factor = 1.0
408                
409                if det_factor < 1e-10:
410                    det_factor = 1e-10  # Évite division par zéro
411                
412                # ✅ CORRECTION 2 : Amplitude normalisée
413                amplitude = T / np.sqrt(det_factor)
414                
415                # Maslov index (0 pour oscillateur harmonique)
416                mu = 0
417                
418                # ✅ CORRECTION 3 : Pic delta au lieu de sinc
419                # Utiliser une gaussienne étroite centrée sur T_k
420                sigma = T_k * 0.05  # Largeur 5% de la période
421                gauss = np.exp(-((t_values - T_k)**2) / (2 * sigma**2))
422                gauss /= (sigma * np.sqrt(2 * np.pi))  # Normalisation
423                
424                phase = S_k / hbar - np.pi * mu / 2
425                contribution = amplitude * gauss * np.exp(1j * phase)
426                
427                # ✅ CORRECTION 4 : Facteur d'amortissement pour grandes répétitions
428                damping = np.exp(-0.1 * k)  # Atténue les contributions lointaines
429                trace += contribution * damping
430        
431        return trace
432    
433    def semiclassical_spectrum(self, periodic_orbits: List[PeriodicOrbit],
434                              hbar: float = 1.0, 
435                              resolution: int = 4000) -> Spectrum:  # ✅ 1000 → 4000
436        """
437        Extract semiclassical spectrum via Fourier transform of trace
438        
439        Parameters
440        ----------
441        periodic_orbits : list
442            Periodic orbits
443        hbar : float
444            Reduced Planck constant
445        resolution : int
446            Number of points
447            
448        Returns
449        -------
450        Spectrum
451            Spectral information
452        """        
453        # ✅ Temps d'intégration plus long
454        t_max = 200 / hbar  # 50 → 200
455        t_values = np.linspace(0, t_max, resolution)
456        
457        trace = self.gutzwiller_trace_formula(periodic_orbits, t_values, hbar)
458        
459        # Fourier transform: t → E
460        energies_fft = fftfreq(len(t_values), d=t_values[1]-t_values[0]) * 2 * np.pi * hbar
461        spectrum_fft = fft(trace)
462        
463        return Spectrum(
464            energies=energies_fft,
465            intensity=np.abs(spectrum_fft),
466            trace_t=t_values,
467            trace=trace
468        )

Analyzes the geometric structure of a symbol H(x, ξ)

This class computes:

  • Hamiltonian flow (geodesics)
  • Jacobian (focusing)
  • Caustics (singularities)
  • Periodic orbits
  • Semiclassical spectrum
SymbolGeometry( symbol: sympy.core.expr.Expr, x_sym: sympy.core.symbol.Symbol, xi_sym: sympy.core.symbol.Symbol)
113    def __init__(self, symbol: sp.Expr, x_sym: sp.Symbol, xi_sym: sp.Symbol):
114        """
115        Initialize with a symbolic Hamiltonian
116        
117        Parameters
118        ----------
119        symbol : sympy expression
120            The Hamiltonian H(x, ξ)
121        x_sym, xi_sym : sympy symbols
122            Position and momentum variables
123        """
124        self.H = symbol
125        self.x_sym = x_sym
126        self.xi_sym = xi_sym
127        
128        # Compute derivatives symbolically (DRY principle)
129        self._compute_derivatives()
130        
131        # Convert to numerical functions (cached)
132        self._lambdify_functions()

Initialize with a symbolic Hamiltonian

Parameters

symbol : sympy expression The Hamiltonian H(x, ξ) x_sym, xi_sym : sympy symbols Position and momentum variables

H
x_sym
xi_sym
def compute_geodesic( self, x0: float, xi0: float, t_max: float, n_points: int = 500) -> src.geometry_1d.Geodesic:
158    def compute_geodesic(self, x0: float, xi0: float, t_max: float, 
159                        n_points: int = 500) -> Geodesic:
160        """
161        Compute geodesic with Jacobian (for caustics detection)
162        
163        Solves the augmented system:
164        dx/dt = ∂H/∂ξ
165        dξ/dt = -∂H/∂x
166        dJ/dt = ∂²H/∂ξ² J + ∂²H/∂x∂ξ K  (variational equation)
167        dK/dt = -∂²H/∂x∂ξ J - ∂²H/∂x² K
168        
169        Parameters
170        ----------
171        x0, xi0 : float
172            Initial conditions
173        t_max : float
174            Final time
175        n_points : int
176            Number of points
177            
178        Returns
179        -------
180        Geodesic
181            Complete geodesic information
182        """
183        def system(t, z):
184            x, xi, J, K = z
185            try:
186                # Hamilton equations
187                dx = float(self.f_dH_dxi(x, xi))
188                dxi = float(-self.f_dH_dx(x, xi))
189                
190                # Variational equations (Jacobian evolution)
191                d2H_dxi2 = float(self.f_d2H_dxi2(x, xi))
192                d2H_dxdxi = float(self.f_d2H_dxdxi(x, xi))
193                d2H_dx2 = float(self.f_d2H_dx2(x, xi))
194                
195                dJ = d2H_dxi2 * J + d2H_dxdxi * K
196                dK = -d2H_dxdxi * J - d2H_dx2 * K
197                
198                return [dx, dxi, dJ, dK]
199            except:
200                return [0, 0, 0, 0]
201        
202        # Initial conditions: J(0)=0, K(0)=1 (standard initial condition)
203        z0 = [x0, xi0, 0.0, 1.0]
204        
205        sol = solve_ivp(
206            system, [0, t_max], z0,
207            t_eval=np.linspace(0, t_max, n_points),
208            method='DOP853',
209            rtol=1e-10, atol=1e-12
210        )
211        
212        # Compute energy along trajectory
213        H_traj = np.array([self.f_H(sol.y[0][i], sol.y[1][i]) 
214                          for i in range(len(sol.t))])
215        
216        return Geodesic(
217            t=sol.t,
218            x=sol.y[0],
219            xi=sol.y[1],
220            H=H_traj,
221            J=sol.y[2],
222            K=sol.y[3]
223        )

Compute geodesic with Jacobian (for caustics detection)

Solves the augmented system: dx/dt = ∂H/∂ξ dξ/dt = -∂H/∂x dJ/dt = ∂²H/∂ξ² J + ∂²H/∂x∂ξ K (variational equation) dK/dt = -∂²H/∂x∂ξ J - ∂²H/∂x² K

Parameters

x0, xi0 : float Initial conditions t_max : float Final time n_points : int Number of points

Returns

Geodesic Complete geodesic information

def find_periodic_orbits( self, energy: float, x_range: Tuple[float, float], xi_range: Tuple[float, float], n_attempts: int = 50, tol_period: float = 0.001) -> List[src.geometry_1d.PeriodicOrbit]:
225    def find_periodic_orbits(self, energy: float, 
226                            x_range: Tuple[float, float],
227                            xi_range: Tuple[float, float],
228                            n_attempts: int = 50,
229                            tol_period: float = 1e-3) -> List[PeriodicOrbit]:
230        """
231        Find periodic orbits at fixed energy
232        
233        Strategy: Sample energy surface H(x,ξ)=E and look for closed orbits
234        
235        Parameters
236        ----------
237        energy : float
238            Target energy level
239        x_range, xi_range : tuple
240            Search domain
241        n_attempts : int
242            Number of initial conditions to try
243        tol_period : float
244            Tolerance for periodicity detection
245            
246        Returns
247        -------
248        list of PeriodicOrbit
249            Found periodic orbits
250        """
251        orbits = []
252        x_samples = np.linspace(x_range[0], x_range[1], int(np.sqrt(n_attempts)))
253        
254        for x0_test in x_samples:
255            # Solve H(x0, ξ0) = E for ξ0
256            def energy_eq(xi0):
257                try:
258                    return self.f_H(x0_test, xi0) - energy
259                except:
260                    return 1e10
261            
262            xi_guesses = np.linspace(xi_range[0], xi_range[1], 5)
263            
264            for xi_guess in xi_guesses:
265                try:
266                    result = fsolve(energy_eq, xi_guess, full_output=True)
267                    
268                    if result[2] != 1:  # Check convergence
269                        continue
270                    
271                    xi0 = result[0][0]
272                    
273                    # Verify we're on energy surface
274                    if abs(self.f_H(x0_test, xi0) - energy) > 1e-6:
275                        continue
276                    
277                    # Integrate to detect periodicity
278                    T_max = 20
279                    geo = self.compute_geodesic(x0_test, xi0, T_max, 2000)
280                    
281                    # Find returns to initial point
282                    distances = np.sqrt((geo.x - x0_test)**2 + (geo.xi - xi0)**2)
283                    
284                    # Find local minima (except t=0)
285                    minima_idx = []
286                    for i in range(10, len(distances)-10):
287                        if (distances[i] < distances[i-1] and 
288                            distances[i] < distances[i+1] and
289                            distances[i] < tol_period):
290                            minima_idx.append(i)
291                    
292                    if minima_idx:
293                        idx_period = minima_idx[0]
294                        period = geo.t[idx_period]
295                        
296                        if period > 0.1 and distances[idx_period] < tol_period:
297                            # Compute action S = ∮ ξ dx
298                            x_cycle = geo.x[:idx_period+1]
299                            xi_cycle = geo.xi[:idx_period+1]
300                            t_cycle = geo.t[:idx_period+1]
301                            
302                            dx_dt = np.gradient(x_cycle, t_cycle)
303                            action = np.trapz(xi_cycle * dx_dt, t_cycle)
304                            
305                            # Compute stability (Lyapunov exponent)
306                            stability = self._compute_stability(x0_test, xi0, period)
307                            
308                            orbits.append(PeriodicOrbit(
309                                x0=x0_test,
310                                xi0=xi0,
311                                period=period,
312                                action=action,
313                                energy=energy,
314                                stability=stability,
315                                x_cycle=x_cycle,
316                                xi_cycle=xi_cycle,
317                                t_cycle=t_cycle
318                            ))
319                
320                except:
321                    continue
322        
323        # Remove duplicates
324        return self._remove_duplicate_orbits(orbits)

Find periodic orbits at fixed energy

Strategy: Sample energy surface H(x,ξ)=E and look for closed orbits

Parameters

energy : float Target energy level x_range, xi_range : tuple Search domain n_attempts : int Number of initial conditions to try tol_period : float Tolerance for periodicity detection

Returns

list of PeriodicOrbit Found periodic orbits

def gutzwiller_trace_formula( self, periodic_orbits: List[src.geometry_1d.PeriodicOrbit], t_values: numpy.ndarray, hbar: float = 1.0) -> numpy.ndarray:
370    def gutzwiller_trace_formula(self, periodic_orbits: List[PeriodicOrbit],
371                                 t_values: np.ndarray, hbar: float = 1.0) -> np.ndarray:
372        """
373        Gutzwiller trace formula (semiclassical)
374        
375        Tr[exp(-iHt/ℏ)] ≈ Σ_γ A_γ exp(iS_γ/ℏ - iπμ_γ/2)
376        
377        Parameters
378        ----------
379        periodic_orbits : list
380            List of periodic orbits
381        t_values : array
382            Time values
383        hbar : float
384            Reduced Planck constant
385            
386        Returns
387        -------
388        array
389            Trace as function of time
390        """
391        trace = np.zeros(len(t_values), dtype=complex)
392        
393        for orb in periodic_orbits:
394            T = orb.period
395            S = orb.action
396            lambda_stab = orb.stability
397            
398            # ✅ CORRECTION 1 : Plus de répétitions (jusqu'à 10)
399            for k in range(1, 11):  # 1 → 11 (au lieu de 5)
400                T_k = k * T
401                S_k = k * S
402                
403                # Stability factor
404                if not np.isnan(lambda_stab) and abs(lambda_stab) > 1e-6:
405                    det_factor = abs(2 * np.sinh(k * lambda_stab * T))
406                else:
407                    det_factor = 1.0
408                
409                if det_factor < 1e-10:
410                    det_factor = 1e-10  # Évite division par zéro
411                
412                # ✅ CORRECTION 2 : Amplitude normalisée
413                amplitude = T / np.sqrt(det_factor)
414                
415                # Maslov index (0 pour oscillateur harmonique)
416                mu = 0
417                
418                # ✅ CORRECTION 3 : Pic delta au lieu de sinc
419                # Utiliser une gaussienne étroite centrée sur T_k
420                sigma = T_k * 0.05  # Largeur 5% de la période
421                gauss = np.exp(-((t_values - T_k)**2) / (2 * sigma**2))
422                gauss /= (sigma * np.sqrt(2 * np.pi))  # Normalisation
423                
424                phase = S_k / hbar - np.pi * mu / 2
425                contribution = amplitude * gauss * np.exp(1j * phase)
426                
427                # ✅ CORRECTION 4 : Facteur d'amortissement pour grandes répétitions
428                damping = np.exp(-0.1 * k)  # Atténue les contributions lointaines
429                trace += contribution * damping
430        
431        return trace

Gutzwiller trace formula (semiclassical)

Tr[exp(-iHt/ℏ)] ≈ Σ_γ A_γ exp(iS_γ/ℏ - iπμ_γ/2)

Parameters

periodic_orbits : list List of periodic orbits t_values : array Time values hbar : float Reduced Planck constant

Returns

array Trace as function of time

def semiclassical_spectrum( self, periodic_orbits: List[src.geometry_1d.PeriodicOrbit], hbar: float = 1.0, resolution: int = 4000) -> src.geometry_1d.Spectrum:
433    def semiclassical_spectrum(self, periodic_orbits: List[PeriodicOrbit],
434                              hbar: float = 1.0, 
435                              resolution: int = 4000) -> Spectrum:  # ✅ 1000 → 4000
436        """
437        Extract semiclassical spectrum via Fourier transform of trace
438        
439        Parameters
440        ----------
441        periodic_orbits : list
442            Periodic orbits
443        hbar : float
444            Reduced Planck constant
445        resolution : int
446            Number of points
447            
448        Returns
449        -------
450        Spectrum
451            Spectral information
452        """        
453        # ✅ Temps d'intégration plus long
454        t_max = 200 / hbar  # 50 → 200
455        t_values = np.linspace(0, t_max, resolution)
456        
457        trace = self.gutzwiller_trace_formula(periodic_orbits, t_values, hbar)
458        
459        # Fourier transform: t → E
460        energies_fft = fftfreq(len(t_values), d=t_values[1]-t_values[0]) * 2 * np.pi * hbar
461        spectrum_fft = fft(trace)
462        
463        return Spectrum(
464            energies=energies_fft,
465            intensity=np.abs(spectrum_fft),
466            trace_t=t_values,
467            trace=trace
468        )

Extract semiclassical spectrum via Fourier transform of trace

Parameters

periodic_orbits : list Periodic orbits hbar : float Reduced Planck constant resolution : int Number of points

Returns

Spectrum Spectral information

class SymbolVisualizer:
474class SymbolVisualizer:
475    """
476    Comprehensive visualization of symbol geometry
477    
478    Produces 15 panels showing:
479    1. Hamiltonian surface (3D)
480    2. Energy level sets (phase space foliation)
481    3. Hamiltonian vector field
482    4. Group velocity ∂H/∂ξ
483    5. Spatial projection (caustics)
484    6. Jacobian (focusing measure)
485    7. Curvature (focusing tendency)
486    8. Energy conservation
487    9. Periodic orbits (phase space)
488    10. Period-energy diagram
489    11. EBK quantization
490    12. Trace formula
491    13. Semiclassical spectrum
492    14. Orbit stability
493    15. Level spacing distribution
494    """
495    
496    def __init__(self, geometry: SymbolGeometry):
497        """
498        Parameters
499        ----------
500        geometry : SymbolGeometry
501            Initialized geometry engine
502        """
503        self.geo = geometry
504    
505    def visualize_complete(self, 
506                          x_range: Tuple[float, float],
507                          xi_range: Tuple[float, float],
508                          geodesics_params: List[Tuple],
509                          E_range: Optional[Tuple[float, float]] = None,
510                          hbar: float = 1.0,
511                          resolution: int = 100) -> Tuple:
512        """
513        Create complete geometric atlas
514        
515        Parameters
516        ----------
517        x_range, xi_range : tuple
518            Domain limits
519        geodesics_params : list of tuples
520            Each tuple: (x0, xi0, t_max, color)
521        E_range : tuple, optional
522            Energy range for spectral analysis
523        hbar : float
524            Reduced Planck constant
525        resolution : int
526            Grid resolution
527            
528        Returns
529        -------
530        fig, geodesics, periodic_orbits, spectrum
531        """
532        # Compute grid
533        x_grid = np.linspace(x_range[0], x_range[1], resolution)
534        xi_grid = np.linspace(xi_range[0], xi_range[1], resolution)
535        X, Xi = np.meshgrid(x_grid, xi_grid)
536        
537        # Evaluate Hamiltonian and derivatives on grid
538        grids = self._evaluate_grids(X, Xi)
539        
540        # Compute geodesics
541        geodesics = self._compute_geodesics(geodesics_params)
542        
543        # Find periodic orbits (if E_range specified)
544        periodic_orbits = []
545        spectrum = None
546        if E_range:
547            energies = np.linspace(E_range[0], E_range[1], 8)
548            for E in energies:
549                orbits = self.geo.find_periodic_orbits(E, x_range, xi_range)
550                periodic_orbits.extend(orbits)
551            
552            if periodic_orbits:
553                spectrum = self.geo.semiclassical_spectrum(periodic_orbits, hbar)
554        
555        # Create figure
556        fig = self._create_figure(X, Xi, grids, geodesics, periodic_orbits, spectrum, hbar)
557        
558        return fig, geodesics, periodic_orbits, spectrum
559    
560    def _evaluate_grids(self, X: np.ndarray, Xi: np.ndarray) -> Dict:
561        """Evaluate all necessary fields on grid (DRY)"""
562        grids = {}
563        
564        for name, func in [
565            ('H', self.geo.f_H),
566            ('dH_dxi', self.geo.f_dH_dxi),
567            ('dH_dx', self.geo.f_dH_dx),
568            ('d2H_dxdxi', self.geo.f_d2H_dxdxi)
569        ]:
570            grid = np.zeros_like(X)
571            for i in range(X.shape[0]):
572                for j in range(X.shape[1]):
573                    try:
574                        grid[i, j] = func(X[i, j], Xi[i, j])
575                    except:
576                        grid[i, j] = np.nan
577            grids[name] = grid
578        
579        return grids
580    
581    def _compute_geodesics(self, params: List[Tuple]) -> List[Geodesic]:
582        """Compute all geodesics"""
583        geodesics = []
584        for p in params:
585            x0, xi0, t_max = p[:3]
586            geo = self.geo.compute_geodesic(x0, xi0, t_max)
587            geo.color = p[3] if len(p) > 3 else 'blue'
588            geodesics.append(geo)
589        return geodesics
590    
591    def _create_figure(self, X, Xi, grids, geodesics, periodic_orbits, spectrum, hbar):
592        """Create the complete visualization figure"""
593        fig = plt.figure(figsize=(24, 18))
594        
595        # Panel 1-8: Geometry
596        self._plot_hamiltonian_surface(fig, X, Xi, grids['H'], geodesics, 1)
597        self._plot_level_sets(fig, X, Xi, grids['H'], geodesics, 2)
598        self._plot_vector_field(fig, X, Xi, grids, geodesics, 3)
599        self._plot_group_velocity(fig, X, Xi, grids['dH_dxi'], geodesics, 4)
600        self._plot_spatial_projection(fig, geodesics, 5)
601        self._plot_jacobian(fig, geodesics, 6)
602        self._plot_curvature(fig, X, Xi, grids['d2H_dxdxi'], geodesics, 7)
603        self._plot_energy_conservation(fig, geodesics, 8)
604        
605        # Panel 9-15: Spectral analysis
606        if periodic_orbits:
607            self._plot_periodic_orbits(fig, X, Xi, grids['H'], periodic_orbits, 9)
608            self._plot_period_energy(fig, periodic_orbits, 10)
609            self._plot_ebk_quantization(fig, periodic_orbits, hbar, 11)
610            
611            if spectrum:
612                self._plot_trace_formula(fig, spectrum, 12)
613                self._plot_spectrum(fig, spectrum, 13)
614                self._plot_stability(fig, periodic_orbits, 14)
615                self._plot_level_spacing(fig, spectrum, 15)
616        
617        plt.suptitle(f'Geometric and Semiclassical Atlas: H = {self.geo.H}',
618                     fontsize=18, fontweight='bold', y=0.995)
619        plt.tight_layout(rect=[0, 0, 1, 0.98])
620
621        
622        return fig
623    
624    # Individual plotting methods (KISS principle: each does one thing)
625    
626    def _plot_hamiltonian_surface(self, fig, X, Xi, H_grid, geodesics, panel):
627        """Panel 1: Hamiltonian surface in 3D"""
628        ax = fig.add_subplot(3, 5, panel, projection='3d')
629        ax.plot_surface(X, Xi, H_grid, cmap='viridis', alpha=0.8, 
630                        linewidth=0, antialiased=True)
631        
632        for geo in geodesics:
633            color = getattr(geo, 'color', 'red')
634            ax.plot(geo.x, geo.xi, geo.H, color=color, linewidth=3)
635            ax.scatter([geo.x[0]], [geo.xi[0]], [geo.H[0]], 
636                       color=color, s=100, edgecolors='black', linewidths=2)
637        
638        ax.set_xlabel('x')
639        ax.set_ylabel('ξ')
640        ax.set_zlabel('H(x,ξ)')
641        ax.set_title('Hamiltonian Surface\n+ Geodesics', fontweight='bold')
642        ax.view_init(elev=25, azim=45)
643        
644        # 🔧 Ajustements pour taille cohérente
645        ax.set_box_aspect((1, 1, 0.6))   # équilibre visuel (x, ξ, H)
646        ax.margins(0)                    # supprime marges internes
647        ax.set_proj_type('ortho')        # projection orthographique = moins de distorsion
648    
649    def _plot_level_sets(self, fig, X, Xi, H_grid, geodesics, panel):
650        """Panel 2: Energy level sets (symplectic foliation)"""
651        ax = fig.add_subplot(3, 5, panel)
652        levels = np.linspace(np.nanmin(H_grid), np.nanmax(H_grid), 20)
653        contour = ax.contour(X, Xi, H_grid, levels=levels, cmap='viridis')
654        ax.clabel(contour, inline=True, fontsize=8)
655        
656        for geo in geodesics:
657            color = getattr(geo, 'color', 'red')
658            ax.plot(geo.x, geo.xi, color=color, linewidth=2.5)
659        
660        ax.set_xlabel('x')
661        ax.set_ylabel('ξ')
662        ax.set_title('Level Sets H=const\nSymplectic Foliation', fontweight='bold')
663        ax.grid(True, alpha=0.3)
664        ax.set_aspect('auto')     
665        ax.margins(0.05)          
666    
667    
668    def _plot_vector_field(self, fig, X, Xi, grids, geodesics, panel):
669        """Panel 3: Hamiltonian vector field"""
670        ax = fig.add_subplot(3, 5, panel)
671        
672        step = max(1, X.shape[0] // 20)
673        X_sub = X[::step, ::step]
674        Xi_sub = Xi[::step, ::step]
675        vx = grids['dH_dxi'][::step, ::step]
676        vy = -grids['dH_dx'][::step, ::step]
677        
678        magnitude = np.sqrt(vx**2 + vy**2)
679        magnitude[magnitude == 0] = 1
680        
681        ax.quiver(X_sub, Xi_sub, vx/magnitude, vy/magnitude,
682                 magnitude, cmap='plasma', alpha=0.7)
683        
684        for geo in geodesics:
685            color = getattr(geo, 'color', 'cyan')
686            ax.plot(geo.x, geo.xi, color=color, linewidth=3)
687        
688        ax.set_xlabel('x')
689        ax.set_ylabel('ξ')
690        ax.set_title('Hamiltonian Vector Field\n(Infinitesimal generator)', fontweight='bold')
691        ax.grid(True, alpha=0.3)
692    
693    def _plot_group_velocity(self, fig, X, Xi, dH_dxi, geodesics, panel):
694        """Panel 4: Group velocity ∂H/∂ξ"""
695        ax = fig.add_subplot(3, 5, panel)
696        
697        im = ax.contourf(X, Xi, dH_dxi, levels=30, cmap='RdBu_r')
698        plt.colorbar(im, ax=ax, label='∂H/∂ξ')
699        ax.contour(X, Xi, dH_dxi, levels=[0], colors='black', 
700                  linewidths=2, linestyles='--')
701        
702        for geo in geodesics:
703            ax.plot(geo.x, geo.xi, color='yellow', linewidth=2)
704        
705        ax.set_xlabel('x')
706        ax.set_ylabel('ξ')
707        ax.set_title('Group Velocity v_g = ∂H/∂ξ\n(Wave propagation speed)', fontweight='bold')
708        ax.grid(True, alpha=0.3)
709    
710    def _plot_spatial_projection(self, fig, geodesics, panel):
711        """Panel 5: Spatial projection (with caustics)"""
712        ax = fig.add_subplot(3, 5, panel)
713        
714        for geo in geodesics:
715            color = getattr(geo, 'color', 'blue')
716            ax.plot(geo.x, geo.t, color=color, linewidth=2.5)
717            
718            # Mark caustics
719            caust_idx = geo.caustics
720            if len(caust_idx) > 0:
721                ax.scatter(geo.x[caust_idx], geo.t[caust_idx],
722                          color='red', s=150, marker='*', zorder=15,
723                          edgecolors='darkred', linewidths=1.5)
724        
725        ax.set_xlabel('x')
726        ax.set_ylabel('t')
727        ax.set_title('Spatial Projection\n★ = Caustics', fontweight='bold')
728        ax.grid(True, alpha=0.3)
729    
730    def _plot_jacobian(self, fig, geodesics, panel):
731        """Panel 6: Jacobian (focusing measure)"""
732        ax = fig.add_subplot(3, 5, panel)
733        
734        for geo in geodesics:
735            color = getattr(geo, 'color', 'blue')
736            ax.plot(geo.t, geo.J, color=color, linewidth=2.5)
737        
738        ax.axhline(0, color='red', linestyle='--', linewidth=2, alpha=0.7)
739        ax.set_xlabel('t')
740        ax.set_ylabel('J = ∂x/∂ξ₀')
741        ax.set_title('Jacobian (Focusing)\nJ→0: rays converge', fontweight='bold')
742        ax.grid(True, alpha=0.3)
743    
744    def _plot_curvature(self, fig, X, Xi, curvature, geodesics, panel):
745        """Panel 7: Sectional curvature"""
746        ax = fig.add_subplot(3, 5, panel)
747        
748        im = ax.contourf(X, Xi, curvature, levels=30, cmap='seismic')
749        plt.colorbar(im, ax=ax, label='∂²H/∂x∂ξ')
750        
751        for geo in geodesics:
752            ax.plot(geo.x, geo.xi, color='lime', linewidth=2)
753        
754        ax.set_xlabel('x')
755        ax.set_ylabel('ξ')
756        ax.set_title('Sectional Curvature\nRed>0: focusing | Blue<0: defocusing', fontweight='bold')
757        ax.grid(True, alpha=0.3)
758    
759    def _plot_energy_conservation(self, fig, geodesics, panel):
760        """Panel 8: Energy conservation (integration quality)"""
761        ax = fig.add_subplot(3, 5, panel)
762        
763        for geo in geodesics:
764            color = getattr(geo, 'color', 'blue')
765            H_variation = (geo.H - geo.H[0]) / (np.abs(geo.H[0]) + 1e-10)
766            ax.semilogy(geo.t, np.abs(H_variation) + 1e-16,
767                       color=color, linewidth=2.5, label=f'E₀={geo.H[0]:.2f}')
768        
769        ax.set_xlabel('t')
770        ax.set_ylabel('|ΔH/H₀|')
771        ax.set_title('Energy Conservation\n(Numerical quality)', fontweight='bold')
772        ax.legend(fontsize=9)
773        ax.grid(True, alpha=0.3, which='both')
774    
775    def _plot_periodic_orbits(self, fig, X, Xi, H_grid, periodic_orbits, panel):
776        """Panel 9: Periodic orbits in phase space"""
777        ax = fig.add_subplot(3, 5, panel)
778        
779        # Energy level sets
780        energies = np.unique([orb.energy for orb in periodic_orbits])
781        contour = ax.contour(X, Xi, H_grid, levels=energies, 
782                            cmap='viridis', linewidths=1.5, alpha=0.6)
783        
784        # Periodic orbits
785        colors_orb = plt.cm.rainbow(np.linspace(0, 1, len(periodic_orbits)))
786        for idx, orb in enumerate(periodic_orbits):
787            ax.plot(orb.x_cycle, orb.xi_cycle, 
788                   color=colors_orb[idx], linewidth=3, alpha=0.8)
789            ax.scatter([orb.x0], [orb.xi0], color=colors_orb[idx], 
790                      s=100, marker='o', edgecolors='black', linewidths=2, zorder=10)
791        
792        ax.set_xlabel('x')
793        ax.set_ylabel('ξ')
794        ax.set_title('Periodic Orbits\n(Phase space)', fontweight='bold')
795        ax.grid(True, alpha=0.3)
796        ax.set_aspect('equal')
797    
798    def _plot_period_energy(self, fig, periodic_orbits, panel):
799        """Panel 10: Period-Energy relation"""
800        ax = fig.add_subplot(3, 5, panel)
801        
802        E_orb = [orb.energy for orb in periodic_orbits]
803        T_orb = [orb.period for orb in periodic_orbits]
804        S_orb = [orb.action for orb in periodic_orbits]
805        
806        scatter = ax.scatter(E_orb, T_orb, c=S_orb, s=150,
807                           cmap='plasma', edgecolors='black', linewidths=1.5)
808        plt.colorbar(scatter, ax=ax, label='Action S')
809        
810        ax.set_xlabel('Energy E')
811        ax.set_ylabel('Period T')
812        ax.set_title('Period-Energy Diagram\nT(E)', fontweight='bold')
813        ax.grid(True, alpha=0.3)
814    
815    def _plot_ebk_quantization(self, fig, periodic_orbits, hbar, panel):
816        """Panel 11: EBK quantization (Einstein-Brillouin-Keller)"""
817        ax = fig.add_subplot(3, 5, panel)
818        
819        E_orb = [orb.energy for orb in periodic_orbits]
820        S_orb = [orb.action for orb in periodic_orbits]
821        T_orb = [orb.period for orb in periodic_orbits]
822        
823        scatter = ax.scatter(E_orb, S_orb, s=150, c=T_orb, cmap='cool',
824                           edgecolors='black', linewidths=1.5)
825        plt.colorbar(scatter, ax=ax, label='Period T')
826        
827        # EBK quantization rules: S = 2πℏ(n + α)
828        E_max = max(E_orb) if E_orb else 10
829        for n in range(15):
830            S_quant = 2 * np.pi * hbar * (n + 0.25)  # α ≈ 1/4 for 1D
831            if S_quant < max(S_orb) if S_orb else 10:
832                ax.axhline(S_quant, color='red', linestyle='--', alpha=0.3, linewidth=1)
833                ax.text(min(E_orb) if E_orb else 0, S_quant, f'n={n}',
834                       fontsize=8, color='red', va='bottom')
835        
836        ax.set_xlabel('Energy E')
837        ax.set_ylabel('Action S')
838        ax.set_title('EBK Quantization\nS = 2πℏ(n+α)', fontweight='bold')
839        ax.grid(True, alpha=0.3)
840    
841    def _plot_trace_formula(self, fig, spectrum, panel):
842        """Panel 12: Gutzwiller trace formula"""
843        ax = fig.add_subplot(3, 5, panel)
844        
845        # Plot only first part for clarity
846        n_plot = min(500, len(spectrum.trace_t))
847        ax.plot(spectrum.trace_t[:n_plot], np.real(spectrum.trace[:n_plot]),
848               'b-', linewidth=1.5, label='Re[Tr]')
849        ax.plot(spectrum.trace_t[:n_plot], np.imag(spectrum.trace[:n_plot]),
850               'r-', linewidth=1.5, alpha=0.7, label='Im[Tr]')
851        
852        ax.set_xlabel('Time t')
853        ax.set_ylabel('Tr[exp(-iHt/ℏ)]')
854        ax.set_title('Gutzwiller Trace Formula\nΣ_γ A_γ exp(iS_γ/ℏ)', fontweight='bold')
855        ax.legend()
856        ax.grid(True, alpha=0.3)
857    
858    def _plot_spectrum(self, fig, spectrum, panel):
859        """Panel 13: Semiclassical spectrum"""
860        ax = fig.add_subplot(3, 5, panel)
861        
862        # Only positive energies
863        mask = spectrum.energies > 0
864        E_positive = spectrum.energies[mask]
865        I_positive = spectrum.intensity[mask]
866        
867        # Detect peaks
868        peaks, properties = find_peaks(I_positive, 
869                                      height=np.max(I_positive)*0.1,
870                                      distance=20)
871        
872        ax.plot(E_positive, I_positive, 'b-', linewidth=1.5)
873        ax.plot(E_positive[peaks], I_positive[peaks],
874               'ro', markersize=10, label='Energy levels')
875        
876        # Annotate first levels
877        for i, peak in enumerate(peaks[:10]):
878            E_level = E_positive[peak]
879            ax.text(E_level, I_positive[peak], f'E_{i}',
880                   fontsize=9, ha='center', va='bottom')
881        
882        ax.set_xlabel('Energy E')
883        ax.set_ylabel('Spectral density')
884        ax.set_title('Semiclassical Spectrum\n(Fourier transform of trace)', fontweight='bold')
885        ax.legend()
886        ax.grid(True, alpha=0.3)
887    
888    def _plot_stability(self, fig, periodic_orbits, panel):
889        """Panel 14: Orbit stability (Lyapunov exponents)"""
890        ax = fig.add_subplot(3, 5, panel)
891        
892        stab = [orb.stability for orb in periodic_orbits]
893        E_stab = [orb.energy for orb in periodic_orbits]
894        T_stab = [orb.period for orb in periodic_orbits]
895        
896        scatter = ax.scatter(E_stab, stab, s=150, c=T_stab, cmap='autumn',
897                           edgecolors='black', linewidths=1.5)
898        plt.colorbar(scatter, ax=ax, label='Period T')
899        ax.axhline(0, color='green', linestyle='--', linewidth=2,
900                  label='Marginal stability')
901        
902        ax.set_xlabel('Energy E')
903        ax.set_ylabel('Lyapunov exponent λ')
904        ax.set_title('Orbit Stability\nλ>0: unstable | λ<0: stable', fontweight='bold')
905        ax.legend()
906        ax.grid(True, alpha=0.3)
907    
908    def _plot_level_spacing(self, fig, spectrum, panel):
909        """Panel 15: Level spacing distribution (integrability test)"""
910        ax = fig.add_subplot(3, 5, panel)
911        
912        # Extract energy levels
913        mask = spectrum.energies > 0
914        E_positive = spectrum.energies[mask]
915        I_positive = spectrum.intensity[mask]
916        
917        peaks, _ = find_peaks(I_positive, height=np.max(I_positive)*0.05, distance=5) 
918        
919        if len(peaks) > 1:
920            E_levels = E_positive[peaks]
921            spacings = np.diff(E_levels)
922            
923            # Normalize spacings
924            s_mean = np.mean(spacings)
925            s_normalized = spacings / s_mean
926            
927            # Histogram
928            ax.hist(s_normalized, bins=20, density=True, alpha=0.7,
929                   color='blue', edgecolor='black', label='Data')
930            
931            # Theoretical distributions
932            s = np.linspace(0, np.max(s_normalized), 100)
933            
934            # Poisson (integrable systems)
935            poisson = np.exp(-s)
936            ax.plot(s, poisson, 'g--', linewidth=2, label='Poisson (integrable)')
937            
938            # Wigner (chaotic systems)
939            wigner = (np.pi * s / 2) * np.exp(-np.pi * s**2 / 4)
940            ax.plot(s, wigner, 'r-', linewidth=2, label='Wigner (chaotic)')
941            
942            ax.set_xlabel('Normalized spacing s')
943            ax.set_ylabel('P(s)')
944            ax.set_title('Level Spacing Distribution\nIntegrable vs Chaotic', fontweight='bold')
945            ax.legend()
946            ax.grid(True, alpha=0.3)

Comprehensive visualization of symbol geometry

Produces 15 panels showing:

  1. Hamiltonian surface (3D)
  2. Energy level sets (phase space foliation)
  3. Hamiltonian vector field
  4. Group velocity ∂H/∂ξ
  5. Spatial projection (caustics)
  6. Jacobian (focusing measure)
  7. Curvature (focusing tendency)
  8. Energy conservation
  9. Periodic orbits (phase space)
  10. Period-energy diagram
  11. EBK quantization
  12. Trace formula
  13. Semiclassical spectrum
  14. Orbit stability
  15. Level spacing distribution
SymbolVisualizer(geometry: SymbolGeometry)
496    def __init__(self, geometry: SymbolGeometry):
497        """
498        Parameters
499        ----------
500        geometry : SymbolGeometry
501            Initialized geometry engine
502        """
503        self.geo = geometry

Parameters

geometry : SymbolGeometry Initialized geometry engine

geo
def visualize_complete( self, x_range: Tuple[float, float], xi_range: Tuple[float, float], geodesics_params: List[Tuple], E_range: Optional[Tuple[float, float]] = None, hbar: float = 1.0, resolution: int = 100) -> Tuple:
505    def visualize_complete(self, 
506                          x_range: Tuple[float, float],
507                          xi_range: Tuple[float, float],
508                          geodesics_params: List[Tuple],
509                          E_range: Optional[Tuple[float, float]] = None,
510                          hbar: float = 1.0,
511                          resolution: int = 100) -> Tuple:
512        """
513        Create complete geometric atlas
514        
515        Parameters
516        ----------
517        x_range, xi_range : tuple
518            Domain limits
519        geodesics_params : list of tuples
520            Each tuple: (x0, xi0, t_max, color)
521        E_range : tuple, optional
522            Energy range for spectral analysis
523        hbar : float
524            Reduced Planck constant
525        resolution : int
526            Grid resolution
527            
528        Returns
529        -------
530        fig, geodesics, periodic_orbits, spectrum
531        """
532        # Compute grid
533        x_grid = np.linspace(x_range[0], x_range[1], resolution)
534        xi_grid = np.linspace(xi_range[0], xi_range[1], resolution)
535        X, Xi = np.meshgrid(x_grid, xi_grid)
536        
537        # Evaluate Hamiltonian and derivatives on grid
538        grids = self._evaluate_grids(X, Xi)
539        
540        # Compute geodesics
541        geodesics = self._compute_geodesics(geodesics_params)
542        
543        # Find periodic orbits (if E_range specified)
544        periodic_orbits = []
545        spectrum = None
546        if E_range:
547            energies = np.linspace(E_range[0], E_range[1], 8)
548            for E in energies:
549                orbits = self.geo.find_periodic_orbits(E, x_range, xi_range)
550                periodic_orbits.extend(orbits)
551            
552            if periodic_orbits:
553                spectrum = self.geo.semiclassical_spectrum(periodic_orbits, hbar)
554        
555        # Create figure
556        fig = self._create_figure(X, Xi, grids, geodesics, periodic_orbits, spectrum, hbar)
557        
558        return fig, geodesics, periodic_orbits, spectrum

Create complete geometric atlas

Parameters

x_range, xi_range : tuple Domain limits geodesics_params : list of tuples Each tuple: (x0, xi0, t_max, color) E_range : tuple, optional Energy range for spectral analysis hbar : float Reduced Planck constant resolution : int Grid resolution

Returns

fig, geodesics, periodic_orbits, spectrum

class SpectralAnalysis:
 953class SpectralAnalysis:
 954    """
 955    Additional spectral analysis tools
 956    """
 957    
 958    @staticmethod
 959    def weyl_law(energy: float, dimension: int, hbar: float = 1.0) -> float:
 960        """
 961        Weyl's law: asymptotic density of states
 962        
 963        N(E) ~ (1/2πℏ)^d × Vol{H(x,p) ≤ E}
 964        
 965        Parameters
 966        ----------
 967        energy : float
 968            Energy threshold
 969        dimension : int
 970            Phase space dimension
 971        hbar : float
 972            Reduced Planck constant
 973            
 974        Returns
 975        -------
 976        float
 977            Approximate number of states below energy E
 978        """
 979        # Simplified: assumes phase space volume ~ E^d
 980        prefactor = (1 / (2 * np.pi * hbar)) ** dimension
 981        return prefactor * (energy ** dimension)
 982    
 983    @staticmethod
 984    def analyze_integrability(spacings: np.ndarray) -> Dict:
 985        """
 986        Determine if system is integrable or chaotic via level statistics
 987        
 988        Parameters
 989        ----------
 990        spacings : array
 991            Energy level spacings
 992            
 993        Returns
 994        -------
 995        dict
 996            Statistical measures and classification
 997        """
 998        s_mean = np.mean(spacings)
 999        s_normalized = spacings / s_mean
1000        
1001        # Brody parameter (0: Poisson, 1: Wigner)
1002        # Fit P(s) = a s^β exp(-b s^(β+1))
1003        # Simplified: use ratio test
1004        
1005        # <s²>/<s>² ratio
1006        ratio = np.mean(s_normalized**2) / (np.mean(s_normalized)**2)
1007        
1008        # Poisson: ratio ≈ 2
1009        # Wigner: ratio ≈ 1.27
1010        
1011        if ratio > 1.7:
1012            classification = "Integrable (Poisson-like)"
1013        elif ratio < 1.4:
1014            classification = "Chaotic (Wigner-like)"
1015        else:
1016            classification = "Intermediate"
1017        
1018        return {
1019            'ratio': ratio,
1020            'mean_spacing': s_mean,
1021            'std_spacing': np.std(spacings),
1022            'classification': classification
1023        }
1024
1025    @staticmethod
1026    def berry_tabor_formula(periodic_orbits: List[PeriodicOrbit], 
1027                           energy: float, 
1028                           window: float = 1.0) -> float:  # ✅ Fenêtre paramétrable
1029        """
1030        Berry-Tabor formula for integrable systems
1031        
1032        Smoothed density of states from periodic orbits
1033        
1034        Parameters
1035        ----------
1036        periodic_orbits : list
1037            Periodic orbits
1038        energy : float
1039            Energy at which to evaluate density
1040            
1041        Returns
1042        -------
1043        float
1044            Density of states ρ(E)
1045        """
1046        density = 0.0
1047        
1048        for orb in periodic_orbits:
1049            # ✅ Contribution gaussienne lissée
1050            weight = np.exp(-((orb.energy - energy)**2) / (2 * window**2))
1051            density += weight * orb.period / (2 * np.pi)
1052        
1053        return density / (window * np.sqrt(2 * np.pi))

Additional spectral analysis tools

@staticmethod
def weyl_law(energy: float, dimension: int, hbar: float = 1.0) -> float:
958    @staticmethod
959    def weyl_law(energy: float, dimension: int, hbar: float = 1.0) -> float:
960        """
961        Weyl's law: asymptotic density of states
962        
963        N(E) ~ (1/2πℏ)^d × Vol{H(x,p) ≤ E}
964        
965        Parameters
966        ----------
967        energy : float
968            Energy threshold
969        dimension : int
970            Phase space dimension
971        hbar : float
972            Reduced Planck constant
973            
974        Returns
975        -------
976        float
977            Approximate number of states below energy E
978        """
979        # Simplified: assumes phase space volume ~ E^d
980        prefactor = (1 / (2 * np.pi * hbar)) ** dimension
981        return prefactor * (energy ** dimension)

Weyl's law: asymptotic density of states

N(E) ~ (1/2πℏ)^d × Vol{H(x,p) ≤ E}

Parameters

energy : float Energy threshold dimension : int Phase space dimension hbar : float Reduced Planck constant

Returns

float Approximate number of states below energy E

@staticmethod
def analyze_integrability(spacings: numpy.ndarray) -> Dict:
 983    @staticmethod
 984    def analyze_integrability(spacings: np.ndarray) -> Dict:
 985        """
 986        Determine if system is integrable or chaotic via level statistics
 987        
 988        Parameters
 989        ----------
 990        spacings : array
 991            Energy level spacings
 992            
 993        Returns
 994        -------
 995        dict
 996            Statistical measures and classification
 997        """
 998        s_mean = np.mean(spacings)
 999        s_normalized = spacings / s_mean
1000        
1001        # Brody parameter (0: Poisson, 1: Wigner)
1002        # Fit P(s) = a s^β exp(-b s^(β+1))
1003        # Simplified: use ratio test
1004        
1005        # <s²>/<s>² ratio
1006        ratio = np.mean(s_normalized**2) / (np.mean(s_normalized)**2)
1007        
1008        # Poisson: ratio ≈ 2
1009        # Wigner: ratio ≈ 1.27
1010        
1011        if ratio > 1.7:
1012            classification = "Integrable (Poisson-like)"
1013        elif ratio < 1.4:
1014            classification = "Chaotic (Wigner-like)"
1015        else:
1016            classification = "Intermediate"
1017        
1018        return {
1019            'ratio': ratio,
1020            'mean_spacing': s_mean,
1021            'std_spacing': np.std(spacings),
1022            'classification': classification
1023        }

Determine if system is integrable or chaotic via level statistics

Parameters

spacings : array Energy level spacings

Returns

dict Statistical measures and classification

@staticmethod
def berry_tabor_formula( periodic_orbits: List[src.geometry_1d.PeriodicOrbit], energy: float, window: float = 1.0) -> float:
1025    @staticmethod
1026    def berry_tabor_formula(periodic_orbits: List[PeriodicOrbit], 
1027                           energy: float, 
1028                           window: float = 1.0) -> float:  # ✅ Fenêtre paramétrable
1029        """
1030        Berry-Tabor formula for integrable systems
1031        
1032        Smoothed density of states from periodic orbits
1033        
1034        Parameters
1035        ----------
1036        periodic_orbits : list
1037            Periodic orbits
1038        energy : float
1039            Energy at which to evaluate density
1040            
1041        Returns
1042        -------
1043        float
1044            Density of states ρ(E)
1045        """
1046        density = 0.0
1047        
1048        for orb in periodic_orbits:
1049            # ✅ Contribution gaussienne lissée
1050            weight = np.exp(-((orb.energy - energy)**2) / (2 * window**2))
1051            density += weight * orb.period / (2 * np.pi)
1052        
1053        return density / (window * np.sqrt(2 * np.pi))

Berry-Tabor formula for integrable systems

Smoothed density of states from periodic orbits

Parameters

periodic_orbits : list Periodic orbits energy : float Energy at which to evaluate density

Returns

float Density of states ρ(E)

class SymbolGeometry2D:
140class SymbolGeometry2D:
141    """
142    Full geometric and semi-classical analysis of a 2D symbol
143    H(x, y, ξ, η) with 4D phase space and rigorous caustic treatment
144    """
145    def __init__(self, symbol: sp.Expr, 
146                 x_sym: sp.Symbol, y_sym: sp.Symbol,
147                 xi_sym: sp.Symbol, eta_sym: sp.Symbol,
148                 hbar: float = 1.0):
149        """
150        Initialization with complete derivative computation for Jacobian evolution
151        Parameters
152        ----------
153        symbol : sympy expression
154            Hamiltonian H(x, y, ξ, η)
155        x_sym, y_sym : sympy symbols
156            Position coordinates
157        xi_sym, eta_sym : sympy symbols
158            Momentum coordinates
159        hbar : float
160            Reduced Planck constant (for quantum aspects)
161        """
162        self.H_sym = symbol
163        self.x_sym = x_sym
164        self.y_sym = y_sym
165        self.xi_sym = xi_sym
166        self.eta_sym = eta_sym
167        self.hbar = hbar
168            
169        print(f"Initializing 2D geometry engine for H = {self.H_sym} with ℏ = {self.hbar}")
170        # --- First derivatives (Hamiltonian vector field) ---
171        dH_x = sp.diff(self.H_sym, self.x_sym)
172        self.dH_dx_sym = _sanitize(dH_x)
173        dH_y = sp.diff(self.H_sym, self.y_sym)
174        self.dH_dy_sym = _sanitize(dH_y)
175        dH_xi = sp.diff(self.H_sym, self.xi_sym)
176        self.dH_dxi_sym = _sanitize(dH_xi)
177        dH_eta = sp.diff(self.H_sym, self.eta_sym)
178        self.dH_deta_sym = _sanitize(dH_eta)
179
180        # --- Second derivatives for variational equations ---
181        d2H_x2 = sp.diff(self.dH_dx_sym, self.x_sym)
182        self.d2H_dx2_sym = _sanitize(d2H_x2)
183        d2H_y2 = sp.diff(self.dH_dy_sym, self.y_sym)
184        self.d2H_dy2_sym = _sanitize(d2H_y2)
185        d2H_xi2 = sp.diff(self.dH_dxi_sym, self.xi_sym)
186        self.d2H_dxi2_sym = _sanitize(d2H_xi2)
187        d2H_eta2 = sp.diff(self.dH_deta_sym, self.eta_sym)
188        self.d2H_deta2_sym = _sanitize(d2H_eta2)
189        d2H_xy = sp.diff(self.dH_dx_sym, self.y_sym)
190        self.d2H_dxdy_sym = _sanitize(d2H_xy)
191        d2H_xxi = sp.diff(self.dH_dx_sym, self.xi_sym)
192        self.d2H_dxdxi_sym = _sanitize(d2H_xxi)
193        d2H_xeta = sp.diff(self.dH_dx_sym, self.eta_sym)
194        self.d2H_dxdeta_sym = _sanitize(d2H_xeta)
195        d2H_yxi = sp.diff(self.dH_dy_sym, self.xi_sym)
196        self.d2H_dydxi_sym = _sanitize(d2H_yxi)
197        d2H_yeta = sp.diff(self.dH_dy_sym, self.eta_sym)
198        self.d2H_dyeta_sym = _sanitize(d2H_yeta)
199        d2H_xieta = sp.diff(self.dH_dxi_sym, self.eta_sym)
200        self.d2H_dxideta_sym = _sanitize(d2H_xieta)
201        # --- Hessian for variational equations ---
202        self.Hessian = sp.Matrix([
203            [self.d2H_dx2_sym, self.d2H_dxdy_sym, self.d2H_dxdxi_sym, self.d2H_dxdeta_sym],
204            [self.d2H_dxdy_sym, self.d2H_dy2_sym, self.d2H_dydxi_sym, self.d2H_dyeta_sym],
205            [self.d2H_dxdxi_sym, self.d2H_dydxi_sym, self.d2H_dxi2_sym, self.d2H_dxideta_sym],
206            [self.d2H_dxdeta_sym, self.d2H_dyeta_sym, self.d2H_dxideta_sym, self.d2H_deta2_sym]
207        ])
208
209        # --- Convert to numerical functions ---
210        self._lambdify_functions()
211  
212    def _safe_lambdify(self, args: tuple, expr: sp.Expr) -> Callable:
213        """Safe conversion of sympy expressions to numerical functions"""
214        if isinstance(expr, (int, float, sp.Integer, sp.Float)):
215            const_val = float(expr)
216            return lambda x, y, xi, eta: np.full_like(x, const_val)
217        try:
218            return sp.lambdify(args, expr, modules=['numpy', 'scipy'])
219        except Exception as e:
220            print(f"Warning: lambdify failed for {expr}. Error: {e}")
221            return lambda x, y, xi, eta: np.full_like(x, np.nan)
222
223    def _lambdify_functions(self):
224        """Convert all symbolic expressions to numerical functions"""
225        args = (self.x_sym, self.y_sym, self.xi_sym, self.eta_sym)
226        self.H_num = self._safe_lambdify(args, self.H_sym)
227        self.dH_dx_num = self._safe_lambdify(args, self.dH_dx_sym)
228        self.dH_dy_num = self._safe_lambdify(args, self.dH_dy_sym)
229        self.dH_dxi_num = self._safe_lambdify(args, self.dH_dxi_sym)
230        self.dH_deta_num = self._safe_lambdify(args, self.dH_deta_sym)
231        # Hessian functions
232        self.second_derivs_funcs = []
233        for i in range(4):
234            row_funcs = []
235            for j in range(4):
236                row_funcs.append(self._safe_lambdify(args, self.Hessian[i,j]))
237            self.second_derivs_funcs.append(row_funcs)
238    
239    def _hamiltonian_system_augmented(self, t: float, z: np.ndarray) -> np.ndarray:
240        """
241        Augmented Hamiltonian system with variational equations for Jacobian evolution
242        State vector z = [x, y, xi, eta, J11, J12, ..., J44] (20 dimensions)
243        """
244        # Extract position and momentum
245        x, y, xi, eta = z[0:4]
246        # Extract Jacobian matrix (4x4)
247        J = z[4:].reshape((4, 4))
248        try:
249            # Hamilton's equations
250            dx = float(self.dH_dxi_num(x, y, xi, eta))
251            dy = float(self.dH_deta_num(x, y, xi, eta))
252            dxi = float(-self.dH_dx_num(x, y, xi, eta))
253            deta = float(-self.dH_dy_num(x, y, xi, eta))
254            # Evaluate numerical Hessian
255            Hessian_num = np.zeros((4, 4))
256            for i in range(4):
257                for j in range(4):
258                    Hessian_num[i, j] = float(self.second_derivs_funcs[i][j](x, y, xi, eta))
259            # Symplectic matrix J0
260            J0 = np.array([
261                [0, 0, 1, 0],
262                [0, 0, 0, 1],
263                [-1, 0, 0, 0],
264                [0, -1, 0, 0]
265            ])
266            # Variational equations: dJ/dt = J @ (J0 @ Hessian)
267            dJ_dt = J @ (J0 @ Hessian_num)
268            # Build derivative vector
269            dz = np.zeros(20)
270            dz[0:4] = [dx, dy, dxi, deta]
271            dz[4:] = dJ_dt.flatten()
272            return dz
273        except Exception as e:
274            print(f"Integration error at t={t}, z={z}: {e}")
275            return np.zeros(20)
276    
277    def compute_geodesic(self, x0: float, y0: float, 
278                        xi0: float, eta0: float,
279                        t_max: float, n_points: int = 500) -> Geodesic2D:
280        """
281        Compute a geodesic with full Jacobian evolution for caustic detection
282        Parameters
283        ----------
284        x0, y0 : float
285            Initial position
286        xi0, eta0 : float
287            Initial momentum
288        t_max : float
289            Final time
290        n_points : int
291            Number of sampling points
292        Returns
293        -------
294        Geodesic2D
295            Structure containing trajectory and caustic analysis
296        """
297        # Initial condition: position, momentum + identity Jacobian
298        z0 = np.zeros(20)
299        z0[0:4] = [x0, y0, xi0, eta0]
300        z0[4:] = np.eye(4).flatten()
301        t_eval = np.linspace(0, t_max, n_points)
302        sol = solve_ivp(
303            self._hamiltonian_system_augmented,
304            [0, t_max], z0, t_eval=t_eval,
305            method='DOP853', rtol=1e-9, atol=1e-12
306        )
307        if not sol.success:
308            print(f"Warning: Integration failed for ({x0}, {y0}, {xi0}, {eta0})")
309        # Extract trajectory data
310        x_traj = sol.y[0]
311        y_traj = sol.y[1]
312        xi_traj = sol.y[2]
313        eta_traj = sol.y[3]
314        # Evaluate energy
315        H_vals = self.H_num(x_traj, y_traj, xi_traj, eta_traj)
316        # Extract and reshape Jacobian matrices
317        J_mats = np.zeros((n_points, 4, 4))
318        for i in range(n_points):
319            J_mats[i] = sol.y[4:, i].reshape((4, 4))
320        # Submatrix for caustic detection: ∂(x,y)/∂(ξ₀,η₀)
321        caustic_matrix = J_mats[:, 0:2, 2:4]
322        # Determinant for caustic detection
323        det_caustic = np.zeros(n_points)
324        for i in range(n_points):
325            det_caustic[i] = np.linalg.det(caustic_matrix[i])
326        # Detect caustic indices (sign change)
327        caustic_indices = np.where(np.diff(np.sign(det_caustic)))[0]
328        return Geodesic2D(
329            t=sol.t,
330            x=x_traj,
331            y=y_traj,
332            xi=xi_traj,
333            eta=eta_traj,
334            H=H_vals,
335            J_full=J_mats,
336            det_caustic=det_caustic,
337            caustic_indices=caustic_indices
338        )
339    
340    def find_periodic_orbits_2d(self, energy: float,
341                               x_range: Tuple[float, float],
342                               y_range: Tuple[float, float],
343                               xi_range: Tuple[float, float],
344                               eta_range: Tuple[float, float],
345                               n_attempts: int = 30) -> List[PeriodicOrbit2D]:
346        """
347        Search for periodic orbits with Maslov index computation
348        """
349        orbits = []
350        # Sample configuration space
351        n_samples = int(np.sqrt(n_attempts))
352        x_samples = np.linspace(x_range[0], x_range[1], n_samples)
353        y_samples = np.linspace(y_range[0], y_range[1], n_samples)
354        for x0 in x_samples:
355            for y0 in y_samples:
356                # Test different momentum directions
357                angles = np.linspace(0, 2*np.pi, 8)
358                for angle in angles:
359                    for r in np.linspace(0.5, 3, 3):
360                        xi0_guess = r * np.cos(angle)
361                        eta0_guess = r * np.sin(angle)
362                        try:
363                            # Energy check
364                            E_test = self.H_num(x0, y0, xi0_guess, eta0_guess)
365                            if abs(E_test - energy) > 0.5:
366                                continue
367                            # Compute geodesic
368                            geo = self.compute_geodesic(x0, y0, xi0_guess, eta0_guess, 15, 1500)
369                            # Search for return points
370                            distances = np.sqrt((geo.x - x0)**2 + (geo.y - y0)**2 +
371                                              (geo.xi - xi0_guess)**2 + (geo.eta - eta0_guess)**2)
372                            minima = []
373                            for i in range(10, len(distances)-10):
374                                if (distances[i] < distances[i-1] and
375                                    distances[i] < distances[i+1] and
376                                    distances[i] < 0.05):
377                                    minima.append(i)
378                            if minima:
379                                idx = minima[0]
380                                period = geo.t[idx]
381                                if period > 0.2 and distances[idx] < 0.05:
382                                    # Compute action
383                                    x_cyc = geo.x[:idx+1]
384                                    y_cyc = geo.y[:idx+1]
385                                    xi_cyc = geo.xi[:idx+1]
386                                    eta_cyc = geo.eta[:idx+1]
387                                    t_cyc = geo.t[:idx+1]
388                                    dx_dt = np.gradient(x_cyc, t_cyc)
389                                    dy_dt = np.gradient(y_cyc, t_cyc)
390                                    action = np.trapz(xi_cyc * dx_dt + eta_cyc * dy_dt, t_cyc)
391                                    # Compute Maslov index (number of caustic crossings)
392                                    maslov_index = len([i for i in geo.caustic_indices if i < idx])
393                                    # Compute stability
394                                    stab1 = self._compute_stability_2d(x0, y0, xi0_guess, eta0_guess, period)
395                                    orbits.append(PeriodicOrbit2D(
396                                        x0=x0, y0=y0,
397                                        xi0=xi0_guess, eta0=eta0_guess,
398                                        period=period,
399                                        action=action,
400                                        energy=energy,
401                                        stability_1=stab1,
402                                        stability_2=0.0,
403                                        x_cycle=x_cyc,
404                                        y_cycle=y_cyc,
405                                        xi_cycle=xi_cyc,
406                                        eta_cycle=eta_cyc,
407                                        t_cycle=t_cyc,
408                                        maslov_index=maslov_index
409                                    ))
410                        except Exception as e:
411                            continue
412        return self._remove_duplicate_orbits_2d(orbits)
413    
414    def _compute_stability_2d(self, x0, y0, xi0, eta0, T):
415        """Compute the largest Lyapunov exponent"""
416        def linearized(t, z):
417            x, y, xi, eta, dx, dy, dxi, deta = z
418            try:
419                vx = float(self.dH_dxi_num(x, y, xi, eta))
420                vy = float(self.dH_deta_num(x, y, xi, eta))
421                vxi = float(-self.dH_dx_num(x, y, xi, eta))
422                veta = float(-self.dH_dy_num(x, y, xi, eta))
423                # Linearization (simplified)
424                A13 = float(self.second_derivs_funcs[2][0](x, y, xi, eta))
425                A24 = float(self.second_derivs_funcs[3][1](x, y, xi, eta))
426                ddx = A13 * dxi
427                ddy = A24 * deta
428                ddxi = 0
429                ddeta = 0
430                return [vx, vy, vxi, veta, ddx, ddy, ddxi, ddeta]
431            except:
432                return [0]*8
433        eps = 1e-6
434        z0 = [x0, y0, xi0, eta0, eps, 0, 0, 0]
435        sol = solve_ivp(linearized, [0, T], z0, method='DOP853', rtol=1e-10)
436        if sol.success and len(sol.y[4]) > 0:
437            pert = np.sqrt(sol.y[4][-1]**2 + sol.y[5][-1]**2)
438            return np.log(pert / eps) / T
439        return np.nan
440    
441    def _remove_duplicate_orbits_2d(self, orbits):
442        """Remove duplicate periodic orbits"""
443        unique = []
444        for orb in orbits:
445            is_dup = False
446            for u_orb in unique:
447                if (abs(orb.period - u_orb.period) < 0.2 and
448                    abs(orb.action - u_orb.action) < 0.2):
449                    is_dup = True
450                    break
451            if not is_dup:
452                unique.append(orb)
453        return unique
454    
455    def detect_caustic_structures(self, geodesics: List[Geodesic2D], 
456                                 t_fixed: float) -> List[CausticStructure]:
457        """
458        Advanced caustic structure detection with classification
459        """
460        caustic_points = []
461        for geo in geodesics:
462            # Find closest time to t_fixed
463            idx = np.argmin(np.abs(geo.t - t_fixed))
464            # Check if near a caustic
465            if abs(geo.det_caustic[idx]) < 0.1:
466                # Classify caustic type
467                caustic_type = self._classify_caustic(geo, idx)
468                # Compute singularity strength
469                strength = 1.0 / (abs(geo.det_caustic[idx]) + 0.01)
470                caustic_points.append({
471                    'x': geo.x[idx],
472                    'y': geo.y[idx],
473                    'energy': geo.energy,
474                    'type': caustic_type,
475                    'strength': strength
476                })
477        if len(caustic_points) < 3:
478            return []
479        # Cluster points into caustic structures
480        caustic_structures = self._cluster_caustic_points(caustic_points, t_fixed)
481        return caustic_structures
482    
483    def _classify_caustic(self, geo: Geodesic2D, idx: int) -> str:
484        """
485        Caustic classification according to catastrophe theory
486        """
487        # Compute curvature near caustic point
488        window = 10
489        start = max(0, idx - window)
490        end = min(len(geo.t), idx + window + 1)
491        if end - start < 5:
492            return 'fold'
493        # Curvature approximation
494        x_window = geo.x[start:end]
495        y_window = geo.y[start:end]
496        dx = np.gradient(x_window)
497        dy = np.gradient(y_window)
498        ddx = np.gradient(dx)
499        ddy = np.gradient(dy)
500        with np.errstate(divide='ignore', invalid='ignore'):
501            curvature = np.abs(dx * ddy - dy * ddx) / (dx**2 + dy**2)**1.5
502        curvature = np.nan_to_num(curvature, nan=0.0, posinf=0.0, neginf=0.0)
503        # Detect cusp points (high curvature)
504        if np.max(curvature) > 2.0 * np.mean(curvature):
505            return 'cusp'
506        return 'fold'
507    
508    def _cluster_caustic_points(self, points: List[dict], t_fixed: float) -> List[CausticStructure]:
509        """Group caustic points into coherent structures"""
510        if not points:
511            return []
512        # Extract coordinates
513        coords = np.array([[p['x'], p['y']] for p in points])
514        # Simple proximity-based clustering
515        clusters = []
516        visited = set()
517        for i, point in enumerate(points):
518            if i in visited:
519                continue
520            # New cluster
521            cluster = [point]
522            visited.add(i)
523            # Find nearby points
524            for j, other in enumerate(points):
525                if j in visited:
526                    continue
527                dist = np.sqrt((point['x'] - other['x'])**2 + (point['y'] - other['y'])**2)
528                if dist < 0.5:  # Distance threshold
529                    cluster.append(other)
530                    visited.add(j)
531            # Create caustic structure
532            xs = np.array([p['x'] for p in cluster])
533            ys = np.array([p['y'] for p in cluster])
534            types = [p['type'] for p in cluster]
535            strengths = [p['strength'] for p in cluster]
536            # Majority type
537            type_counts = {}
538            for t in types:
539                type_counts[t] = type_counts.get(t, 0) + 1
540            dominant_type = max(type_counts.items(), key=lambda x: x[1])[0]
541            # Maslov index (approximation)
542            maslov_index = 1 if dominant_type == 'fold' else 2
543            clusters.append(CausticStructure(
544                x=xs,
545                y=ys,
546                t=t_fixed,
547                energy=cluster[0]['energy'],
548                type=dominant_type,
549                maslov_index=maslov_index,
550                strength=np.mean(strengths)
551            ))
552        return clusters
553    
554    def compute_phase_space_volume(self, E_max: float, x_range: tuple, y_range: tuple,
555                                 xi_range: tuple, eta_range: tuple, 
556                                 n_samples: int = 200000) -> float:
557        """Monte Carlo estimation of phase space volume for H ≤ E_max"""
558        # Generate random samples
559        x_samples = np.random.uniform(x_range[0], x_range[1], n_samples)
560        y_samples = np.random.uniform(y_range[0], y_range[1], n_samples)
561        xi_samples = np.random.uniform(xi_range[0], xi_range[1], n_samples)
562        eta_samples = np.random.uniform(eta_range[0], eta_range[1], n_samples)
563        # Evaluate Hamiltonian
564        H_vals = self.H_num(x_samples, y_samples, xi_samples, eta_samples)
565        # Count points where H ≤ E_max
566        volume_ratio = np.mean(H_vals <= E_max)
567        # Total phase space volume
568        total_volume = ((x_range[1]-x_range[0]) * (y_range[1]-y_range[0]) * 
569                       (xi_range[1]-xi_range[0]) * (eta_range[1]-eta_range[0]))
570        return volume_ratio * total_volume

Full geometric and semi-classical analysis of a 2D symbol H(x, y, ξ, η) with 4D phase space and rigorous caustic treatment

SymbolGeometry2D( symbol: sympy.core.expr.Expr, x_sym: sympy.core.symbol.Symbol, y_sym: sympy.core.symbol.Symbol, xi_sym: sympy.core.symbol.Symbol, eta_sym: sympy.core.symbol.Symbol, hbar: float = 1.0)
145    def __init__(self, symbol: sp.Expr, 
146                 x_sym: sp.Symbol, y_sym: sp.Symbol,
147                 xi_sym: sp.Symbol, eta_sym: sp.Symbol,
148                 hbar: float = 1.0):
149        """
150        Initialization with complete derivative computation for Jacobian evolution
151        Parameters
152        ----------
153        symbol : sympy expression
154            Hamiltonian H(x, y, ξ, η)
155        x_sym, y_sym : sympy symbols
156            Position coordinates
157        xi_sym, eta_sym : sympy symbols
158            Momentum coordinates
159        hbar : float
160            Reduced Planck constant (for quantum aspects)
161        """
162        self.H_sym = symbol
163        self.x_sym = x_sym
164        self.y_sym = y_sym
165        self.xi_sym = xi_sym
166        self.eta_sym = eta_sym
167        self.hbar = hbar
168            
169        print(f"Initializing 2D geometry engine for H = {self.H_sym} with ℏ = {self.hbar}")
170        # --- First derivatives (Hamiltonian vector field) ---
171        dH_x = sp.diff(self.H_sym, self.x_sym)
172        self.dH_dx_sym = _sanitize(dH_x)
173        dH_y = sp.diff(self.H_sym, self.y_sym)
174        self.dH_dy_sym = _sanitize(dH_y)
175        dH_xi = sp.diff(self.H_sym, self.xi_sym)
176        self.dH_dxi_sym = _sanitize(dH_xi)
177        dH_eta = sp.diff(self.H_sym, self.eta_sym)
178        self.dH_deta_sym = _sanitize(dH_eta)
179
180        # --- Second derivatives for variational equations ---
181        d2H_x2 = sp.diff(self.dH_dx_sym, self.x_sym)
182        self.d2H_dx2_sym = _sanitize(d2H_x2)
183        d2H_y2 = sp.diff(self.dH_dy_sym, self.y_sym)
184        self.d2H_dy2_sym = _sanitize(d2H_y2)
185        d2H_xi2 = sp.diff(self.dH_dxi_sym, self.xi_sym)
186        self.d2H_dxi2_sym = _sanitize(d2H_xi2)
187        d2H_eta2 = sp.diff(self.dH_deta_sym, self.eta_sym)
188        self.d2H_deta2_sym = _sanitize(d2H_eta2)
189        d2H_xy = sp.diff(self.dH_dx_sym, self.y_sym)
190        self.d2H_dxdy_sym = _sanitize(d2H_xy)
191        d2H_xxi = sp.diff(self.dH_dx_sym, self.xi_sym)
192        self.d2H_dxdxi_sym = _sanitize(d2H_xxi)
193        d2H_xeta = sp.diff(self.dH_dx_sym, self.eta_sym)
194        self.d2H_dxdeta_sym = _sanitize(d2H_xeta)
195        d2H_yxi = sp.diff(self.dH_dy_sym, self.xi_sym)
196        self.d2H_dydxi_sym = _sanitize(d2H_yxi)
197        d2H_yeta = sp.diff(self.dH_dy_sym, self.eta_sym)
198        self.d2H_dyeta_sym = _sanitize(d2H_yeta)
199        d2H_xieta = sp.diff(self.dH_dxi_sym, self.eta_sym)
200        self.d2H_dxideta_sym = _sanitize(d2H_xieta)
201        # --- Hessian for variational equations ---
202        self.Hessian = sp.Matrix([
203            [self.d2H_dx2_sym, self.d2H_dxdy_sym, self.d2H_dxdxi_sym, self.d2H_dxdeta_sym],
204            [self.d2H_dxdy_sym, self.d2H_dy2_sym, self.d2H_dydxi_sym, self.d2H_dyeta_sym],
205            [self.d2H_dxdxi_sym, self.d2H_dydxi_sym, self.d2H_dxi2_sym, self.d2H_dxideta_sym],
206            [self.d2H_dxdeta_sym, self.d2H_dyeta_sym, self.d2H_dxideta_sym, self.d2H_deta2_sym]
207        ])
208
209        # --- Convert to numerical functions ---
210        self._lambdify_functions()

Initialization with complete derivative computation for Jacobian evolution

Parameters

symbol : sympy expression Hamiltonian H(x, y, ξ, η) x_sym, y_sym : sympy symbols Position coordinates xi_sym, eta_sym : sympy symbols Momentum coordinates hbar : float Reduced Planck constant (for quantum aspects)

H_sym
x_sym
y_sym
xi_sym
eta_sym
hbar
dH_dx_sym
dH_dy_sym
dH_dxi_sym
dH_deta_sym
d2H_dx2_sym
d2H_dy2_sym
d2H_dxi2_sym
d2H_deta2_sym
d2H_dxdy_sym
d2H_dxdxi_sym
d2H_dxdeta_sym
d2H_dydxi_sym
d2H_dyeta_sym
d2H_dxideta_sym
Hessian
def compute_geodesic( self, x0: float, y0: float, xi0: float, eta0: float, t_max: float, n_points: int = 500) -> src.geometry_2d.Geodesic2D:
277    def compute_geodesic(self, x0: float, y0: float, 
278                        xi0: float, eta0: float,
279                        t_max: float, n_points: int = 500) -> Geodesic2D:
280        """
281        Compute a geodesic with full Jacobian evolution for caustic detection
282        Parameters
283        ----------
284        x0, y0 : float
285            Initial position
286        xi0, eta0 : float
287            Initial momentum
288        t_max : float
289            Final time
290        n_points : int
291            Number of sampling points
292        Returns
293        -------
294        Geodesic2D
295            Structure containing trajectory and caustic analysis
296        """
297        # Initial condition: position, momentum + identity Jacobian
298        z0 = np.zeros(20)
299        z0[0:4] = [x0, y0, xi0, eta0]
300        z0[4:] = np.eye(4).flatten()
301        t_eval = np.linspace(0, t_max, n_points)
302        sol = solve_ivp(
303            self._hamiltonian_system_augmented,
304            [0, t_max], z0, t_eval=t_eval,
305            method='DOP853', rtol=1e-9, atol=1e-12
306        )
307        if not sol.success:
308            print(f"Warning: Integration failed for ({x0}, {y0}, {xi0}, {eta0})")
309        # Extract trajectory data
310        x_traj = sol.y[0]
311        y_traj = sol.y[1]
312        xi_traj = sol.y[2]
313        eta_traj = sol.y[3]
314        # Evaluate energy
315        H_vals = self.H_num(x_traj, y_traj, xi_traj, eta_traj)
316        # Extract and reshape Jacobian matrices
317        J_mats = np.zeros((n_points, 4, 4))
318        for i in range(n_points):
319            J_mats[i] = sol.y[4:, i].reshape((4, 4))
320        # Submatrix for caustic detection: ∂(x,y)/∂(ξ₀,η₀)
321        caustic_matrix = J_mats[:, 0:2, 2:4]
322        # Determinant for caustic detection
323        det_caustic = np.zeros(n_points)
324        for i in range(n_points):
325            det_caustic[i] = np.linalg.det(caustic_matrix[i])
326        # Detect caustic indices (sign change)
327        caustic_indices = np.where(np.diff(np.sign(det_caustic)))[0]
328        return Geodesic2D(
329            t=sol.t,
330            x=x_traj,
331            y=y_traj,
332            xi=xi_traj,
333            eta=eta_traj,
334            H=H_vals,
335            J_full=J_mats,
336            det_caustic=det_caustic,
337            caustic_indices=caustic_indices
338        )

Compute a geodesic with full Jacobian evolution for caustic detection

Parameters

x0, y0 : float Initial position xi0, eta0 : float Initial momentum t_max : float Final time n_points : int Number of sampling points

Returns

Geodesic2D Structure containing trajectory and caustic analysis

def find_periodic_orbits_2d( self, energy: float, x_range: Tuple[float, float], y_range: Tuple[float, float], xi_range: Tuple[float, float], eta_range: Tuple[float, float], n_attempts: int = 30) -> List[src.geometry_2d.PeriodicOrbit2D]:
340    def find_periodic_orbits_2d(self, energy: float,
341                               x_range: Tuple[float, float],
342                               y_range: Tuple[float, float],
343                               xi_range: Tuple[float, float],
344                               eta_range: Tuple[float, float],
345                               n_attempts: int = 30) -> List[PeriodicOrbit2D]:
346        """
347        Search for periodic orbits with Maslov index computation
348        """
349        orbits = []
350        # Sample configuration space
351        n_samples = int(np.sqrt(n_attempts))
352        x_samples = np.linspace(x_range[0], x_range[1], n_samples)
353        y_samples = np.linspace(y_range[0], y_range[1], n_samples)
354        for x0 in x_samples:
355            for y0 in y_samples:
356                # Test different momentum directions
357                angles = np.linspace(0, 2*np.pi, 8)
358                for angle in angles:
359                    for r in np.linspace(0.5, 3, 3):
360                        xi0_guess = r * np.cos(angle)
361                        eta0_guess = r * np.sin(angle)
362                        try:
363                            # Energy check
364                            E_test = self.H_num(x0, y0, xi0_guess, eta0_guess)
365                            if abs(E_test - energy) > 0.5:
366                                continue
367                            # Compute geodesic
368                            geo = self.compute_geodesic(x0, y0, xi0_guess, eta0_guess, 15, 1500)
369                            # Search for return points
370                            distances = np.sqrt((geo.x - x0)**2 + (geo.y - y0)**2 +
371                                              (geo.xi - xi0_guess)**2 + (geo.eta - eta0_guess)**2)
372                            minima = []
373                            for i in range(10, len(distances)-10):
374                                if (distances[i] < distances[i-1] and
375                                    distances[i] < distances[i+1] and
376                                    distances[i] < 0.05):
377                                    minima.append(i)
378                            if minima:
379                                idx = minima[0]
380                                period = geo.t[idx]
381                                if period > 0.2 and distances[idx] < 0.05:
382                                    # Compute action
383                                    x_cyc = geo.x[:idx+1]
384                                    y_cyc = geo.y[:idx+1]
385                                    xi_cyc = geo.xi[:idx+1]
386                                    eta_cyc = geo.eta[:idx+1]
387                                    t_cyc = geo.t[:idx+1]
388                                    dx_dt = np.gradient(x_cyc, t_cyc)
389                                    dy_dt = np.gradient(y_cyc, t_cyc)
390                                    action = np.trapz(xi_cyc * dx_dt + eta_cyc * dy_dt, t_cyc)
391                                    # Compute Maslov index (number of caustic crossings)
392                                    maslov_index = len([i for i in geo.caustic_indices if i < idx])
393                                    # Compute stability
394                                    stab1 = self._compute_stability_2d(x0, y0, xi0_guess, eta0_guess, period)
395                                    orbits.append(PeriodicOrbit2D(
396                                        x0=x0, y0=y0,
397                                        xi0=xi0_guess, eta0=eta0_guess,
398                                        period=period,
399                                        action=action,
400                                        energy=energy,
401                                        stability_1=stab1,
402                                        stability_2=0.0,
403                                        x_cycle=x_cyc,
404                                        y_cycle=y_cyc,
405                                        xi_cycle=xi_cyc,
406                                        eta_cycle=eta_cyc,
407                                        t_cycle=t_cyc,
408                                        maslov_index=maslov_index
409                                    ))
410                        except Exception as e:
411                            continue
412        return self._remove_duplicate_orbits_2d(orbits)

Search for periodic orbits with Maslov index computation

def detect_caustic_structures( self, geodesics: List[src.geometry_2d.Geodesic2D], t_fixed: float) -> List[src.geometry_2d.CausticStructure]:
455    def detect_caustic_structures(self, geodesics: List[Geodesic2D], 
456                                 t_fixed: float) -> List[CausticStructure]:
457        """
458        Advanced caustic structure detection with classification
459        """
460        caustic_points = []
461        for geo in geodesics:
462            # Find closest time to t_fixed
463            idx = np.argmin(np.abs(geo.t - t_fixed))
464            # Check if near a caustic
465            if abs(geo.det_caustic[idx]) < 0.1:
466                # Classify caustic type
467                caustic_type = self._classify_caustic(geo, idx)
468                # Compute singularity strength
469                strength = 1.0 / (abs(geo.det_caustic[idx]) + 0.01)
470                caustic_points.append({
471                    'x': geo.x[idx],
472                    'y': geo.y[idx],
473                    'energy': geo.energy,
474                    'type': caustic_type,
475                    'strength': strength
476                })
477        if len(caustic_points) < 3:
478            return []
479        # Cluster points into caustic structures
480        caustic_structures = self._cluster_caustic_points(caustic_points, t_fixed)
481        return caustic_structures

Advanced caustic structure detection with classification

def compute_phase_space_volume( self, E_max: float, x_range: tuple, y_range: tuple, xi_range: tuple, eta_range: tuple, n_samples: int = 200000) -> float:
554    def compute_phase_space_volume(self, E_max: float, x_range: tuple, y_range: tuple,
555                                 xi_range: tuple, eta_range: tuple, 
556                                 n_samples: int = 200000) -> float:
557        """Monte Carlo estimation of phase space volume for H ≤ E_max"""
558        # Generate random samples
559        x_samples = np.random.uniform(x_range[0], x_range[1], n_samples)
560        y_samples = np.random.uniform(y_range[0], y_range[1], n_samples)
561        xi_samples = np.random.uniform(xi_range[0], xi_range[1], n_samples)
562        eta_samples = np.random.uniform(eta_range[0], eta_range[1], n_samples)
563        # Evaluate Hamiltonian
564        H_vals = self.H_num(x_samples, y_samples, xi_samples, eta_samples)
565        # Count points where H ≤ E_max
566        volume_ratio = np.mean(H_vals <= E_max)
567        # Total phase space volume
568        total_volume = ((x_range[1]-x_range[0]) * (y_range[1]-y_range[0]) * 
569                       (xi_range[1]-xi_range[0]) * (eta_range[1]-eta_range[0]))
570        return volume_ratio * total_volume

Monte Carlo estimation of phase space volume for H ≤ E_max

class SymbolVisualizer2D:
 575class SymbolVisualizer2D:
 576    """
 577    Complete visualization combining geometric and physical aspects
 578    """
 579    def __init__(self, geometry: SymbolGeometry2D):
 580        self.geo = geometry
 581
 582    def visualize_complete(self,
 583                          x_range: Tuple[float, float],
 584                          y_range: Tuple[float, float],
 585                          xi_range: Tuple[float, float],
 586                          eta_range: Tuple[float, float],
 587                          geodesics_params: List[Tuple],
 588                          E_range: Optional[Tuple[float, float]] = None,
 589                          hbar: float = 1.0,
 590                          resolution: int = 50) -> Tuple:
 591        """
 592        Create a complete 18-panel visualization combining geometry and physics
 593        Parameters
 594        ----------
 595        x_range, y_range : tuple
 596            Configuration space domain
 597        xi_range, eta_range : tuple
 598            Momentum space domain
 599        geodesics_params : list
 600            Geodesic parameters: (x0, y0, xi0, eta0, t_max, color)
 601        E_range : tuple, optional
 602            Energy interval for spectral analysis
 603        hbar : float
 604            Reduced Planck constant
 605        resolution : int
 606            Grid resolution
 607        Returns
 608        -------
 609        fig, geodesics, periodic_orbits, caustics
 610        """
 611        # Compute geodesics with caustic detection
 612        geodesics = self._compute_geodesics(geodesics_params)
 613        # Search for periodic orbits
 614        periodic_orbits = []
 615        if E_range:
 616            energies = np.linspace(E_range[0], E_range[1], 5)
 617            for E in energies:
 618                orbits = self.geo.find_periodic_orbits_2d(
 619                    E, x_range, y_range, xi_range, eta_range, n_attempts=20
 620                )
 621                periodic_orbits.extend(orbits)
 622        # Detect caustic structures
 623        caustics = []
 624        if geodesics:
 625            t_samples = np.linspace(0, geodesics[0].t[-1], 5)
 626            for t in t_samples:
 627                caustics.extend(self.geo.detect_caustic_structures(geodesics, t))
 628        # Create full figure
 629        fig = self._create_complete_figure(
 630            E_range, x_range, y_range, xi_range, eta_range,
 631            geodesics, periodic_orbits, caustics, hbar, resolution
 632        )
 633        return fig, geodesics, periodic_orbits, caustics
 634    
 635    def _compute_geodesics(self, params):
 636        """Compute geodesics with caustic detection"""
 637        geodesics = []
 638        for p in params:
 639            x0, y0, xi0, eta0, t_max = p[:5]
 640            geo = self.geo.compute_geodesic(x0, y0, xi0, eta0, t_max)
 641            geo.color = p[5] if len(p) > 5 else 'blue'
 642            geodesics.append(geo)
 643        return geodesics
 644
 645    
 646    def _create_complete_figure(self, E_range, x_range, y_range, xi_range, eta_range,
 647                               geodesics, periodic_orbits, caustics, hbar, resolution):
 648        """Creates an adaptive multi-panel figure: only relevant panels are displayed."""
 649        
 650        # --- List of panels with explicit call signatures ---
 651        panels_to_plot = []
 652    
 653        # Always safe to plot if data exists
 654        if geodesics:
 655            panels_to_plot.append(lambda ax_spec: self._plot_energy_surface_2d(fig, ax_spec, x_range, y_range, geodesics, resolution))
 656            panels_to_plot.append(lambda ax_spec: self._plot_configuration_space(fig, ax_spec, geodesics, caustics))
 657            panels_to_plot.append(lambda ax_spec: self._plot_phase_projection_x(fig, ax_spec, geodesics))
 658            panels_to_plot.append(lambda ax_spec: self._plot_phase_projection_y(fig, ax_spec, geodesics))
 659            panels_to_plot.append(lambda ax_spec: self._plot_momentum_space(fig, ax_spec, geodesics))
 660            panels_to_plot.append(lambda ax_spec: self._plot_vector_field_2d(fig, ax_spec, x_range, y_range, geodesics, resolution))
 661            panels_to_plot.append(lambda ax_spec: self._plot_group_velocity_2d(fig, ax_spec, x_range, y_range, geodesics, resolution))
 662            panels_to_plot.append(lambda ax_spec: self._plot_caustic_curves_2d(fig, ax_spec, geodesics, caustics))
 663            panels_to_plot.append(lambda ax_spec: self._plot_jacobian_evolution(fig, ax_spec, geodesics))
 664            panels_to_plot.append(lambda ax_spec: self._plot_energy_conservation_2d(fig, ax_spec, geodesics))
 665            panels_to_plot.append(lambda ax_spec: self._plot_poincare_x(fig, ax_spec, geodesics))
 666            panels_to_plot.append(lambda ax_spec: self._plot_poincare_y(fig, ax_spec, geodesics))
 667            panels_to_plot.append(lambda ax_spec: self._plot_caustic_network(fig, ax_spec, x_range, y_range, geodesics))
 668    
 669        if geodesics and caustics:
 670            pass  # already handled above
 671    
 672        if periodic_orbits:
 673            panels_to_plot.append(lambda ax_spec: self._plot_periodic_orbits_3d(fig, ax_spec, periodic_orbits))
 674            panels_to_plot.append(lambda ax_spec: self._plot_action_energy_2d(fig, ax_spec, periodic_orbits))
 675            panels_to_plot.append(lambda ax_spec: self._plot_torus_quantization(fig, ax_spec, periodic_orbits, hbar))
 676            if len(periodic_orbits) > 2:
 677                panels_to_plot.append(lambda ax_spec: self._plot_level_spacing_2d(fig, ax_spec, periodic_orbits))
 678    
 679        if periodic_orbits and E_range:
 680            panels_to_plot.append(lambda ax_spec: self._plot_spectral_density_with_caustics(fig, ax_spec, periodic_orbits, E_range))
 681    
 682        # Always plot Maslov (demo)
 683        panels_to_plot.append(lambda ax_spec: self._plot_maslov_index_phase_shifts(fig, ax_spec, geodesics, caustics))
 684    
 685        if E_range:
 686            panels_to_plot.append(lambda ax_spec: self._plot_phase_space_volume(fig, ax_spec, E_range, x_range, y_range, xi_range, eta_range))
 687    
 688        # --- Handle empty case ---
 689        if not panels_to_plot:
 690            fig, ax = plt.subplots(figsize=(10, 6))
 691            ax.text(0.5, 0.5, "No panels to display for this Hamiltonian.",
 692                    ha='center', va='center', fontsize=16, transform=ax.transAxes)
 693            ax.set_axis_off()
 694            return fig
 695    
 696        # --- Dynamic layout ---
 697        n = len(panels_to_plot)
 698        if n <= 5:
 699            cols, rows = n, 1
 700        elif n <= 10:
 701            cols, rows = 5, 2
 702        elif n <= 15:
 703            cols, rows = 5, 3
 704        else:
 705            cols, rows = 5, (n + 4) // 5
 706    
 707        figsize = (4.8 * cols, 4.0 * rows)
 708        fig = plt.figure(figsize=figsize)
 709        gs = GridSpec(rows, cols, figure=fig, hspace=0.5, wspace=0.3)
 710        plt.suptitle(f'Geometric and Semiclassical Atlas: H = {self.geo.H_sym} (ℏ={hbar})',
 711                     fontsize=18, fontweight='bold', y=0.98)
 712    
 713        # --- Plot all panels ---
 714        for idx, plot_cmd in enumerate(panels_to_plot):
 715            if idx >= rows * cols:
 716                break
 717            row = idx // cols
 718            col = idx % cols
 719            subplot_spec = gs[row, col]
 720            try:
 721                plot_cmd(subplot_spec)
 722            except Exception as e:
 723                ax = fig.add_subplot(subplot_spec)
 724                ax.text(0.5, 0.5, f"[Error]\n{type(e).__name__}", ha='center', va='center', color='red')
 725                ax.set_axis_off()
 726    
 727        plt.tight_layout(rect=[0, 0.02, 1, 0.95])
 728        return fig
 729
 730    # ======== DETAILED VISUALIZATION METHODS ========
 731    def _plot_energy_surface_2d(self, fig, subplot_spec, x_range, y_range, geodesics, res):
 732        """Energy surface H(x,y) at fixed (ξ,η)"""
 733        ax = fig.add_subplot(subplot_spec, projection='3d')
 734        x = np.linspace(x_range[0], x_range[1], res)
 735        y = np.linspace(y_range[0], y_range[1], res)
 736        X, Y = np.meshgrid(x, y)
 737        # Evaluate at reference momentum
 738        xi_ref, eta_ref = 1.0, 1.0
 739        Z = np.zeros_like(X)
 740        for i in range(X.shape[0]):
 741            for j in range(X.shape[1]):
 742                try:
 743                    Z[i,j] = self.geo.H_num(X[i,j], Y[i,j], xi_ref, eta_ref)
 744                except:
 745                    Z[i,j] = np.nan
 746        # Surface with transparency to see geodesics
 747        ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.6, edgecolor='none')
 748        # Geodesics on the surface
 749        for geo in geodesics[:5]:
 750            H_geo = np.array([self.geo.H_num(geo.x[i], geo.y[i], xi_ref, eta_ref)
 751                             for i in range(len(geo.t))])
 752            color = getattr(geo, 'color', 'red')
 753            ax.plot(geo.x, geo.y, H_geo, color=color, linewidth=2.5)
 754        ax.set_xlabel('x')
 755        ax.set_ylabel('y')
 756        ax.set_zlabel('H')
 757        ax.set_title('Energy Surface\nH(x,y,ξ₀,η₀)', fontweight='bold', fontsize=10)
 758        ax.view_init(elev=25, azim=-45)
 759    
 760    def _plot_configuration_space(self, fig, subplot_spec, geodesics, caustics):
 761        """Configuration space (x,y) with trajectories and caustics"""
 762        ax = fig.add_subplot(subplot_spec)
 763        
 764        # Trajectories - use thinner lines and lighter colors for better visibility
 765        for geo in geodesics:
 766            color = getattr(geo, 'color', 'blue')
 767            ax.plot(geo.x, geo.y, color=color, linewidth=1.5, alpha=0.7, zorder=5)
 768            ax.scatter([geo.x[0]], [geo.y[0]], color=color, s=80, 
 769                      marker='o', edgecolors='black', linewidths=1.5, zorder=10)
 770        
 771        # Caustic points on trajectories - keep as stars but reduce size slightly
 772        for geo in geodesics:
 773            caust_x, caust_y = geo.caustic_points
 774            if len(caust_x) > 0:
 775                ax.scatter(caust_x, caust_y, c='red', s=80, marker='*',  # Reduced from 120
 776                          edgecolors='darkred', linewidths=1.0, zorder=15,
 777                          label='Caustic points')
 778        
 779        # Caustic structures - use smaller, more subtle markers
 780        for caust in caustics:
 781            color_map = {'fold': 'red', 'cusp': 'magenta', 'swallowtail': 'orange'}
 782            color = color_map.get(caust.type, 'red')
 783            # Use a small circle or dot instead of a large X
 784            marker = 'o'  # You can also try '.' for even smaller dots
 785            # Reduce size significantly and increase transparency
 786            size = 30  # Fixed size for clarity, or use: max(15, min(50, 80 * caust.strength / 2))
 787            alpha_val = 0.5  # More transparent to avoid obscuring trajectories
 788            
 789            ax.scatter(caust.x, caust.y, c=color, s=size, marker=marker,
 790                      edgecolors='none',  # Remove edge for cleaner look
 791                      linewidths=0, alpha=alpha_val, zorder=12,  # zorder between traj and points
 792                      label=f'Caustic {caust.type} (μ={caust.maslov_index})')
 793        
 794        ax.set_xlabel('x')
 795        ax.set_ylabel('y')
 796        ax.set_title('Configuration Space\n★ = caustics', fontweight='bold', fontsize=10)
 797        ax.grid(True, alpha=0.3)
 798        ax.set_aspect('equal')
 799        
 800        # Legend without duplicates
 801        handles, labels = ax.get_legend_handles_labels()
 802        by_label = dict(zip(labels, handles))
 803        if by_label:
 804            ax.legend(by_label.values(), by_label.keys(), fontsize=8, loc='upper right')
 805    
 806    def _plot_jacobian_evolution(self, fig, subplot_spec, geodesics):
 807        """Evolution of Jacobian determinant with caustic detection"""
 808        ax = fig.add_subplot(subplot_spec)
 809        for geo in geodesics:
 810            color = getattr(geo, 'color', 'blue')
 811            ax.plot(geo.t, geo.det_caustic, color=color, linewidth=2.5, alpha=0.9,
 812                   label=f'E={geo.energy:.2f}')
 813            # Mark caustic points
 814            for idx in geo.caustic_indices:
 815                ax.scatter(geo.t[idx], geo.det_caustic[idx], s=100, marker='*',
 816                          color='red', edgecolor='darkred', zorder=10)
 817        ax.axhline(0, color='red', linestyle='--', linewidth=2, alpha=0.7)
 818        ax.set_xlabel('Time t')
 819        ax.set_ylabel('det(∂(x,y)/∂(ξ₀,η₀))')
 820        ax.set_title('Jacobian Determinant\nZeros = caustics', fontweight='bold', fontsize=10)
 821        ax.grid(True, alpha=0.3)
 822        ax.legend(fontsize=8)
 823    
 824    def _plot_maslov_index_phase_shifts(self, fig, subplot_spec, geodesics, caustics):
 825        """Visualization of phase shifts due to Maslov index"""
 826        ax = fig.add_subplot(subplot_spec)
 827        # Simulate wavefunction crossing caustics
 828        x_demo = np.linspace(-4, 4, 1000)
 829        k = 2.0  # Wavenumber
 830        # Free wavefunction (before caustic)
 831        psi_free = np.exp(1j * k * x_demo**2 / 2)
 832        # Simulate phase shifts at caustics
 833        caustic_positions = [-2.0, 0.0, 2.0]  # Caustic positions
 834        maslov_indices = [1, 2, 1]  # Maslov index for each caustic
 835        psi_with_shifts = np.zeros_like(psi_free, dtype=complex)
 836        current_phase = 0.0
 837        for i, x in enumerate(x_demo):
 838            # Check if crossing a caustic
 839            for j, caust_x in enumerate(caustic_positions):
 840                if abs(x - caust_x) < 0.05:
 841                    current_phase -= maslov_indices[j] * np.pi / 2
 842            psi_with_shifts[i] = psi_free[i] * np.exp(1j * current_phase)
 843        # Plot real parts
 844        ax.plot(x_demo, np.real(psi_free), 'b-', alpha=0.8, linewidth=2, 
 845                label='Re[ψ] before caustics')
 846        ax.plot(x_demo, np.real(psi_with_shifts), 'r-', alpha=0.8, linewidth=2, 
 847                label='Re[ψ] after caustics')
 848        # Mark caustic positions
 849        for i, caust_x in enumerate(caustic_positions):
 850            ax.axvline(caust_x, color='k', linestyle='--', alpha=0.7,
 851                      label=f'Caustic μ={maslov_indices[i]}')
 852        ax.set_xlabel('Position x')
 853        ax.set_ylabel('Re[ψ(x)]')
 854        ax.set_title('Maslov Index\nPhase shifts at caustics', fontweight='bold', fontsize=10)
 855        ax.set_ylim(-1.5, 1.5)
 856        ax.grid(True, alpha=0.3)
 857        ax.legend(fontsize=8, loc='upper right')
 858    
 859    def _plot_spectral_density_with_caustics(self, fig, subplot_spec, periodic_orbits, E_range):
 860        """Spectral density with caustic corrections"""
 861        ax = fig.add_subplot(subplot_spec)
 862        if not periodic_orbits:
 863            ax.text(0.5, 0.5, 'No periodic orbits', 
 864                   ha='center', va='center', transform=ax.transAxes)
 865            return
 866        # Sort orbits by energy
 867        orbits_sorted = sorted(periodic_orbits, key=lambda x: x.energy)
 868        energies = np.array([orb.energy for orb in orbits_sorted])
 869        periods = np.array([orb.period for orb in orbits_sorted])
 870        # Compute state density ρ(E) = T(E)/(2π) for integrable systems
 871        if len(energies) > 1:
 872            dE = np.diff(energies)
 873            dT = np.diff(periods)
 874            rho_E = np.zeros_like(energies)
 875            rho_E[1:-1] = (periods[2:] - periods[:-2]) / (energies[2:] - energies[:-2])
 876            if len(rho_E) > 2:
 877                rho_E[0] = (periods[1] - periods[0]) / (energies[1] - energies[0])
 878                rho_E[-1] = (periods[-1] - periods[-2]) / (energies[-1] - energies[-2])
 879            rho_E = np.maximum(rho_E, 0)  # Avoid negative values
 880            # Caustic correction (oscillatory terms)
 881            rho_osc = np.zeros_like(rho_E)
 882            for orb in orbits_sorted:
 883                # Amplitude depending on Maslov index
 884                amp = 0.3 * np.exp(-orb.maslov_index/2) * orb.period
 885                phase = orb.action / self.geo.hbar - np.pi * orb.maslov_index / 2
 886                idx = np.argmin(np.abs(energies - orb.energy))
 887                if 0 <= idx < len(rho_osc):
 888                    rho_osc[idx] += amp * np.cos(phase)
 889            # Smooth curve
 890            E_fine = np.linspace(E_range[0], E_range[1], 500)
 891            from scipy.interpolate import interp1d
 892            try:
 893                interp_rho = interp1d(energies, rho_E, kind='cubic', fill_value="extrapolate")
 894                interp_osc = interp1d(energies, rho_osc, kind='cubic', fill_value="extrapolate")
 895                rho_smooth = np.maximum(0, interp_rho(E_fine))
 896                rho_osc_smooth = interp_osc(E_fine)
 897                # Plot components
 898                ax.plot(E_fine, rho_smooth, 'k-', linewidth=2.5, 
 899                       label='Smooth (Weyl)')
 900                ax.plot(E_fine, rho_smooth + rho_osc_smooth, 'b-', linewidth=2,
 901                       label='Total with caustics')
 902                ax.fill_between(E_fine, rho_smooth, rho_smooth + rho_osc_smooth, 
 903                               where=rho_osc_smooth>0, color='#ff9999', alpha=0.4,
 904                               label='Caustic corrections')
 905            except:
 906                ax.plot(energies, rho_E, 'b-o', linewidth=2, label='State density ρ(E)')
 907        ax.set_xlabel('Energy E')
 908        ax.set_ylabel('ρ(E)')
 909        ax.set_title('Spectral Density\nwith caustic corrections', fontweight='bold', fontsize=10)
 910        ax.grid(True, alpha=0.3)
 911        ax.legend(fontsize=8)
 912    
 913    def _plot_phase_space_volume(self, fig, subplot_spec, E_range, x_range, y_range, xi_range, eta_range):
 914        """Phase space volume via Monte Carlo"""
 915        ax = fig.add_subplot(subplot_spec)
 916        # Compute volume for different energies
 917        E_vals = np.linspace(E_range[0], E_range[1], 8)
 918        volumes = []
 919        print("Computing phase space volume (Monte Carlo)...")
 920        for E in E_vals:
 921            vol = self.geo.compute_phase_space_volume(E, x_range, y_range, xi_range, eta_range, n_samples=50000)
 922            volumes.append(vol)
 923            print(f"  E={E:.2f}, Volume={vol:.4f}")
 924        # Weyl law: N(E) ~ Vol/(2πℏ)²
 925        d = 2  # Dimension
 926        weyl_constant = (2 * np.pi * self.geo.hbar) ** d
 927        N_weyl = np.array(volumes) / weyl_constant
 928        ax.plot(E_vals, N_weyl, 'b-o', linewidth=2.5, markersize=8, 
 929                label=f'Weyl law: N(E) ~ Vol/(2πℏ)²', color='#1f77b4')
 930        # Conceptual caustic correction
 931        if len(E_vals) > 3:
 932            oscillation_freq = 5 / (E_range[1] - E_range[0])
 933            correction = 0.15 * N_weyl * np.sin(2 * np.pi * oscillation_freq * (E_vals - E_vals[0]) + 0.7)
 934            N_corrected = N_weyl + correction
 935            from scipy.ndimage import gaussian_filter1d
 936            N_corrected_smooth = gaussian_filter1d(N_corrected, sigma=1.0)
 937            ax.plot(E_vals, N_corrected_smooth, 'r--', linewidth=2, 
 938                   label="With caustic corrections", alpha=0.9)
 939        ax.set_xlabel('Energy E')
 940        ax.set_ylabel('N(E) (Number of states)')
 941        ax.set_title('Phase Space Volume\n(Monte Carlo)', fontweight='bold', fontsize=10)
 942        ax.grid(True, alpha=0.3)
 943        ax.legend(fontsize=8)
 944    
 945    def _plot_caustic_network(self, fig, subplot_spec, x_range, y_range, geodesics):
 946        """Caustic network with multiple initial conditions"""
 947        ax = fig.add_subplot(subplot_spec)
 948        if not geodesics:
 949            ax.text(0.5, 0.5, 'No geodesics', 
 950                   ha='center', va='center', transform=ax.transAxes)
 951            return
 952        # Use first geodesic as reference
 953        E_ref = geodesics[0].energy
 954        t_max = geodesics[0].t[-1]
 955        # Generate trajectory family
 956        n_family = 15
 957        x0_vals = np.linspace(x_range[0], x_range[1], n_family)
 958        caustic_points = []
 959        for x0 in x0_vals:
 960            try:
 961                # Solve for y0, xi0, eta0 keeping energy constant
 962                def energy_eq(vars):
 963                    y_val, xi_val, eta_val = vars
 964                    return self.geo.H_num(x0, y_val, xi_val, eta_val) - E_ref
 965                # Use initial values of first geodesic as guess
 966                y0_guess = geodesics[0].y[0]
 967                xi0_guess = geodesics[0].xi[0]
 968                eta0_guess = geodesics[0].eta[0]
 969                sol = fsolve(energy_eq, [y0_guess, xi0_guess, eta0_guess])
 970                if np.all(np.isfinite(sol)):
 971                    y0_new, xi0_new, eta0_new = sol
 972                    # Compute trajectory
 973                    geo = self.geo.compute_geodesic(x0, y0_new, xi0_new, eta0_new, t_max, n_points=300)
 974                    # Plot trajectory
 975                    ax.plot(geo.x, geo.y, color='blue', alpha=0.3, linewidth=1)
 976                    # Collect caustic points
 977                    caust_x, caust_y = geo.caustic_points
 978                    for i in range(len(caust_x)):
 979                        caustic_points.append((caust_x[i], caust_y[i]))
 980            except Exception as e:
 981                continue
 982        # Plot caustic points
 983        if caustic_points:
 984            caustic_points = np.array(caustic_points)
 985            ax.scatter(caustic_points[:, 0], caustic_points[:, 1], 
 986                      s=30, c='red', alpha=0.8, edgecolor='none',
 987                      label='Caustic points')
 988        ax.set_xlabel('x')
 989        ax.set_ylabel('y')
 990        ax.set_title('Caustic Network\n(Multiple initial conditions)', fontweight='bold', fontsize=10)
 991        ax.set_xlim(x_range)
 992        ax.set_ylim(y_range)
 993        ax.grid(True, alpha=0.3)
 994        ax.legend(fontsize=8)
 995    
 996    # ======== STANDARD VISUALIZATION METHODS (similar to v1) ========
 997    # Following methods are similar to v1 but enhanced
 998    # to integrate caustics and new data structures
 999    def _plot_phase_projection_x(self, fig, subplot_spec, geodesics):
1000        """Phase space projection (x,ξ)"""
1001        ax = fig.add_subplot(subplot_spec)
1002        for geo in geodesics:
1003            color = getattr(geo, 'color', 'blue')
1004            ax.plot(geo.x, geo.xi, color=color, linewidth=2, alpha=0.8)
1005            ax.scatter([geo.x[0]], [geo.xi[0]], color=color, s=80,
1006                      marker='o', edgecolors='black', linewidths=1.5)
1007        ax.set_xlabel('x')
1008        ax.set_ylabel('ξ')
1009        ax.set_title('Phase Space (x,ξ)', fontweight='bold', fontsize=10)
1010        ax.grid(True, alpha=0.3)
1011    
1012    def _plot_phase_projection_y(self, fig, subplot_spec, geodesics):
1013        """Phase space projection (y,η)"""
1014        ax = fig.add_subplot(subplot_spec)
1015        for geo in geodesics:
1016            color = getattr(geo, 'color', 'blue')
1017            ax.plot(geo.y, geo.eta, color=color, linewidth=2, alpha=0.8)
1018            ax.scatter([geo.y[0]], [geo.eta[0]], color=color, s=80,
1019                      marker='o', edgecolors='black', linewidths=1.5)
1020        ax.set_xlabel('y')
1021        ax.set_ylabel('η')
1022        ax.set_title('Phase Space (y,η)', fontweight='bold', fontsize=10)
1023        ax.grid(True, alpha=0.3)
1024    
1025    def _plot_momentum_space(self, fig, subplot_spec, geodesics):
1026        """Momentum space (ξ,η)"""
1027        ax = fig.add_subplot(subplot_spec)
1028        for geo in geodesics:
1029            color = getattr(geo, 'color', 'blue')
1030            ax.plot(geo.xi, geo.eta, color=color, linewidth=2, alpha=0.8)
1031            ax.scatter([geo.xi[0]], [geo.eta[0]], color=color, s=80,
1032                      marker='o', edgecolors='black', linewidths=1.5)
1033        ax.set_xlabel('ξ')
1034        ax.set_ylabel('η')
1035        ax.set_title('Momentum Space\n(ξ,η)', fontweight='bold', fontsize=10)
1036        ax.grid(True, alpha=0.3)
1037        ax.set_aspect('equal')
1038    
1039    def _plot_vector_field_2d(self, fig, subplot_spec, x_range, y_range, geodesics, res):
1040        """Vector field in configuration space"""
1041        ax = fig.add_subplot(subplot_spec)
1042        x = np.linspace(x_range[0], x_range[1], res//2)
1043        y = np.linspace(y_range[0], y_range[1], res//2)
1044        X, Y = np.meshgrid(x, y)
1045        # Evaluate vector field at reference momentum
1046        xi_ref, eta_ref = 1.0, 1.0
1047        VX = np.zeros_like(X)
1048        VY = np.zeros_like(Y)
1049        for i in range(X.shape[0]):
1050            for j in range(X.shape[1]):
1051                try:
1052                    VX[i,j] = self.geo.dH_dxi_num(X[i,j], Y[i,j], xi_ref, eta_ref)
1053                    VY[i,j] = self.geo.dH_deta_num(X[i,j], Y[i,j], xi_ref, eta_ref)
1054                except:
1055                    VX[i,j] = VY[i,j] = np.nan
1056        # Magnitude for coloring
1057        magnitude = np.sqrt(VX**2 + VY**2)
1058        magnitude[magnitude == 0] = 1
1059        # Normalized vector field
1060        ax.quiver(X, Y, VX/magnitude, VY/magnitude, magnitude, 
1061                 cmap='plasma', alpha=0.7, scale=30)
1062        # Overlay geodesics
1063        for geo in geodesics[:5]:
1064            color = getattr(geo, 'color', 'white')
1065            ax.plot(geo.x, geo.y, color=color, linewidth=2.5, alpha=0.9)
1066        ax.set_xlabel('x')
1067        ax.set_ylabel('y')
1068        ax.set_title('Vector Field\nFlow in configuration space', fontweight='bold', fontsize=10)
1069        ax.set_aspect('equal')
1070    
1071    def _plot_group_velocity_2d(self, fig, subplot_spec, x_range, y_range, geodesics, res):
1072        """Group velocity magnitude |∇_p H|"""
1073        ax = fig.add_subplot(subplot_spec)
1074        x = np.linspace(x_range[0], x_range[1], res)
1075        y = np.linspace(y_range[0], y_range[1], res)
1076        X, Y = np.meshgrid(x, y)
1077        # Group velocity at reference momentum
1078        xi_ref, eta_ref = 1.0, 1.0
1079        V_mag = np.zeros_like(X)
1080        for i in range(X.shape[0]):
1081            for j in range(X.shape[1]):
1082                try:
1083                    vx = self.geo.dH_dxi_num(X[i,j], Y[i,j], xi_ref, eta_ref)
1084                    vy = self.geo.dH_deta_num(X[i,j], Y[i,j], xi_ref, eta_ref)
1085                    V_mag[i,j] = np.sqrt(vx**2 + vy**2)
1086                except:
1087                    V_mag[i,j] = np.nan
1088        # Heatmap
1089        im = ax.contourf(X, Y, V_mag, levels=20, cmap='hot')
1090        plt.colorbar(im, ax=ax, label='|v_g|')
1091        # Geodesics
1092        for geo in geodesics[:5]:
1093            ax.plot(geo.x, geo.y, 'cyan', linewidth=2, alpha=0.8)
1094        ax.set_xlabel('x')
1095        ax.set_ylabel('y')
1096        ax.set_title('Group Velocity\n|∇_p H|', fontweight='bold', fontsize=10)
1097        ax.set_aspect('equal')
1098    
1099    def _plot_caustic_curves_2d(self, fig, subplot_spec, geodesics, caustics):
1100        """Caustic curves in (x,y) space"""
1101        ax = fig.add_subplot(subplot_spec)
1102        # All geodesics
1103        for geo in geodesics:
1104            color = getattr(geo, 'color', 'lightblue')
1105            ax.plot(geo.x, geo.y, color=color, linewidth=1.5, alpha=0.5)
1106            # Caustic points on each geodesic
1107            caust_x, caust_y = geo.caustic_points
1108            if len(caust_x) > 0:
1109                ax.scatter(caust_x, caust_y, c='red', s=80, marker='*', 
1110                          edgecolors='darkred', linewidths=1.5, zorder=10)
1111        # Complete caustic structures
1112        for caust in caustics:
1113            color_map = {'fold': 'red', 'cusp': 'magenta', 'swallowtail': 'orange'}
1114            color = color_map.get(caust.type, 'red')
1115            # If enough points, plot smoothed curve
1116            if len(caust.x) > 3:
1117                ax.plot(caust.x, caust.y, color=color, linewidth=3, 
1118                       label=f'Caustic {caust.type} (μ={caust.maslov_index})')
1119            else:
1120                ax.scatter(caust.x, caust.y, c=color, s=100, marker='X',
1121                          edgecolors='black', linewidths=1.5,
1122                          label=f'Caustic {caust.type}')
1123        ax.set_xlabel('x')
1124        ax.set_ylabel('y')
1125        ax.set_title('Caustic Curves\n★ = points on geodesics', fontweight='bold', fontsize=10)
1126        ax.grid(True, alpha=0.3)
1127        ax.set_aspect('equal')
1128        # Legend without duplicates
1129        handles, labels = ax.get_legend_handles_labels()
1130        by_label = dict(zip(labels, handles))
1131        if by_label:
1132            ax.legend(by_label.values(), by_label.keys(), fontsize=8)
1133    
1134    def _plot_energy_conservation_2d(self, fig, subplot_spec, geodesics):
1135        """Energy conservation verification"""
1136        ax = fig.add_subplot(subplot_spec)
1137        for geo in geodesics:
1138            color = getattr(geo, 'color', 'blue')
1139            H_var = (geo.H - geo.H[0]) / (np.abs(geo.H[0]) + 1e-10)
1140            ax.semilogy(geo.t, np.abs(H_var) + 1e-16,
1141                       color=color, linewidth=2, label=f'E={geo.H[0]:.2f}')
1142        ax.set_xlabel('Time t')
1143        ax.set_ylabel('|ΔH/H₀|')
1144        ax.set_title('Energy Conservation\nNumerical quality', fontweight='bold', fontsize=10)
1145        ax.legend(fontsize=8)
1146        ax.grid(True, alpha=0.3, which='both')
1147    
1148    def _plot_poincare_x(self, fig, subplot_spec, geodesics):
1149        """Poincaré section (x,ξ) at y=0"""
1150        ax = fig.add_subplot(subplot_spec)
1151        for geo in geodesics:
1152            # Find y=0 crossings
1153            crossings_x = []
1154            crossings_xi = []
1155            for i in range(len(geo.y)-1):
1156                if geo.y[i] * geo.y[i+1] < 0:  # Sign change
1157                    alpha = -geo.y[i] / (geo.y[i+1] - geo.y[i])
1158                    x_cross = geo.x[i] + alpha * (geo.x[i+1] - geo.x[i])
1159                    xi_cross = geo.xi[i] + alpha * (geo.xi[i+1] - geo.xi[i])
1160                    crossings_x.append(x_cross)
1161                    crossings_xi.append(xi_cross)
1162            if crossings_x:
1163                color = getattr(geo, 'color', 'blue')
1164                ax.scatter(crossings_x, crossings_xi, c=color, s=50, alpha=0.7)
1165        ax.set_xlabel('x')
1166        ax.set_ylabel('ξ')
1167        ax.set_title('Poincaré Section\n(x,ξ) at y=0', fontweight='bold', fontsize=10)
1168        ax.grid(True, alpha=0.3)
1169    
1170    def _plot_poincare_y(self, fig, subplot_spec, geodesics):
1171        """Poincaré section (y,η) at x=0"""
1172        ax = fig.add_subplot(subplot_spec)
1173        for geo in geodesics:
1174            # Find x=0 crossings
1175            crossings_y = []
1176            crossings_eta = []
1177            for i in range(len(geo.x)-1):
1178                if geo.x[i] * geo.x[i+1] < 0:
1179                    alpha = -geo.x[i] / (geo.x[i+1] - geo.x[i])
1180                    y_cross = geo.y[i] + alpha * (geo.y[i+1] - geo.y[i])
1181                    eta_cross = geo.eta[i] + alpha * (geo.eta[i+1] - geo.eta[i])
1182                    crossings_y.append(y_cross)
1183                    crossings_eta.append(eta_cross)
1184            if crossings_y:
1185                color = getattr(geo, 'color', 'blue')
1186                ax.scatter(crossings_y, crossings_eta, c=color, s=50, alpha=0.7)
1187        ax.set_xlabel('y')
1188        ax.set_ylabel('η')
1189        ax.set_title('Poincaré Section\n(y,η) at x=0', fontweight='bold', fontsize=10)
1190        ax.grid(True, alpha=0.3)
1191    
1192    def _plot_periodic_orbits_3d(self, fig, subplot_spec, periodic_orbits):
1193        """Periodic orbits in 3D (x,y,t)"""
1194        ax = fig.add_subplot(subplot_spec, projection='3d')
1195        colors = plt.cm.rainbow(np.linspace(0, 1, min(10, len(periodic_orbits))))
1196        for idx, orb in enumerate(periodic_orbits[:10]):  # Limit for clarity
1197            ax.plot(orb.x_cycle, orb.y_cycle, orb.t_cycle,
1198                   color=colors[idx], linewidth=2.5, alpha=0.8)
1199            ax.scatter([orb.x0], [orb.y0], [0], color=colors[idx],
1200                      s=100, marker='o', edgecolors='black', linewidths=2)
1201        ax.set_xlabel('x')
1202        ax.set_ylabel('y')
1203        ax.set_zlabel('t')
1204        ax.set_title('Periodic Orbits\nSpace-time view', fontweight='bold', fontsize=10)
1205    
1206    def _plot_action_energy_2d(self, fig, subplot_spec, periodic_orbits):
1207        """Action vs Energy"""
1208        ax = fig.add_subplot(subplot_spec)
1209        E_orb = [orb.energy for orb in periodic_orbits]
1210        S_orb = [orb.action for orb in periodic_orbits]
1211        T_orb = [orb.period for orb in periodic_orbits]
1212        scatter = ax.scatter(E_orb, S_orb, c=T_orb, s=150,
1213                           cmap='plasma', edgecolors='black', linewidths=1.5)
1214        plt.colorbar(scatter, ax=ax, label='Period T')
1215        ax.set_xlabel('Energy E')
1216        ax.set_ylabel('Action S')
1217        ax.set_title('Action-Energy\nS(E)', fontweight='bold', fontsize=10)
1218        ax.grid(True, alpha=0.3)
1219    
1220    def _plot_torus_quantization(self, fig, subplot_spec, periodic_orbits, hbar):
1221        """Torus quantization (KAM theory)"""
1222        ax = fig.add_subplot(subplot_spec)
1223        E_orb = [orb.energy for orb in periodic_orbits]
1224        S_orb = [orb.action for orb in periodic_orbits]
1225        scatter = ax.scatter(E_orb, S_orb, s=150, c='blue',
1226                           edgecolors='black', linewidths=1.5, label='Orbits')
1227        # EBK quantization for 2D: S_i = 2πℏ(n_i + α_i)
1228        # Simplified for one dimension
1229        E_max = max(E_orb) if E_orb else 10
1230        for n in range(20):
1231            S_quant = 2 * np.pi * hbar * (n + 0.5)
1232            if S_quant < max(S_orb) if S_orb else 10:
1233                ax.axhline(S_quant, color='red', linestyle='--', alpha=0.3)
1234                ax.text(min(E_orb) if E_orb else 0, S_quant, 
1235                       f'n={n}', fontsize=7, color='red')
1236        ax.set_xlabel('Energy E')
1237        ax.set_ylabel('Action S')
1238        ax.set_title('Torus Quantization\nKAM theory', fontweight='bold', fontsize=10)
1239        ax.legend(fontsize=8)
1240        ax.grid(True, alpha=0.3)
1241    
1242    def _plot_level_spacing_2d(self, fig, subplot_spec, periodic_orbits):
1243        """Level spacing distribution"""
1244        ax = fig.add_subplot(subplot_spec)
1245        # Extract unique energies
1246        energies = sorted(set(orb.energy for orb in periodic_orbits))
1247        if len(energies) > 2:
1248            spacings = np.diff(energies)
1249            # Normalize
1250            s_mean = np.mean(spacings)
1251            s_norm = spacings / s_mean
1252            # Histogram
1253            ax.hist(s_norm, bins=15, density=True, alpha=0.7,
1254                   color='blue', edgecolor='black', label='Data')
1255            # Theoretical curves
1256            s = np.linspace(0, np.max(s_norm), 100)
1257            # Poisson (integrable systems)
1258            poisson = np.exp(-s)
1259            ax.plot(s, poisson, 'g--', linewidth=2, label='Poisson (Integrable)')
1260            # Wigner (chaotic systems)
1261            wigner = (np.pi * s / 2) * np.exp(-np.pi * s**2 / 4)
1262            ax.plot(s, wigner, 'r-', linewidth=2, label='Wigner (Chaotic)')
1263            ax.set_xlabel('Normalized spacing s')
1264            ax.set_ylabel('P(s)')
1265            ax.set_title('Level Spacing\nIntegrable vs Chaotic', fontweight='bold', fontsize=10)
1266            ax.legend(fontsize=8)
1267            ax.grid(True, alpha=0.3)

Complete visualization combining geometric and physical aspects

SymbolVisualizer2D(geometry: SymbolGeometry2D)
579    def __init__(self, geometry: SymbolGeometry2D):
580        self.geo = geometry
geo
def visualize_complete( self, x_range: Tuple[float, float], y_range: Tuple[float, float], xi_range: Tuple[float, float], eta_range: Tuple[float, float], geodesics_params: List[Tuple], E_range: Optional[Tuple[float, float]] = None, hbar: float = 1.0, resolution: int = 50) -> Tuple:
582    def visualize_complete(self,
583                          x_range: Tuple[float, float],
584                          y_range: Tuple[float, float],
585                          xi_range: Tuple[float, float],
586                          eta_range: Tuple[float, float],
587                          geodesics_params: List[Tuple],
588                          E_range: Optional[Tuple[float, float]] = None,
589                          hbar: float = 1.0,
590                          resolution: int = 50) -> Tuple:
591        """
592        Create a complete 18-panel visualization combining geometry and physics
593        Parameters
594        ----------
595        x_range, y_range : tuple
596            Configuration space domain
597        xi_range, eta_range : tuple
598            Momentum space domain
599        geodesics_params : list
600            Geodesic parameters: (x0, y0, xi0, eta0, t_max, color)
601        E_range : tuple, optional
602            Energy interval for spectral analysis
603        hbar : float
604            Reduced Planck constant
605        resolution : int
606            Grid resolution
607        Returns
608        -------
609        fig, geodesics, periodic_orbits, caustics
610        """
611        # Compute geodesics with caustic detection
612        geodesics = self._compute_geodesics(geodesics_params)
613        # Search for periodic orbits
614        periodic_orbits = []
615        if E_range:
616            energies = np.linspace(E_range[0], E_range[1], 5)
617            for E in energies:
618                orbits = self.geo.find_periodic_orbits_2d(
619                    E, x_range, y_range, xi_range, eta_range, n_attempts=20
620                )
621                periodic_orbits.extend(orbits)
622        # Detect caustic structures
623        caustics = []
624        if geodesics:
625            t_samples = np.linspace(0, geodesics[0].t[-1], 5)
626            for t in t_samples:
627                caustics.extend(self.geo.detect_caustic_structures(geodesics, t))
628        # Create full figure
629        fig = self._create_complete_figure(
630            E_range, x_range, y_range, xi_range, eta_range,
631            geodesics, periodic_orbits, caustics, hbar, resolution
632        )
633        return fig, geodesics, periodic_orbits, caustics

Create a complete 18-panel visualization combining geometry and physics

Parameters

x_range, y_range : tuple Configuration space domain xi_range, eta_range : tuple Momentum space domain geodesics_params : list Geodesic parameters: (x0, y0, xi0, eta0, t_max, color) E_range : tuple, optional Energy interval for spectral analysis hbar : float Reduced Planck constant resolution : int Grid resolution

Returns

fig, geodesics, periodic_orbits, caustics

class Utilities2D:
1356class Utilities2D:
1357    """Additional analysis tools for 2D systems"""
1358    @staticmethod
1359    def compute_winding_number(geo: Geodesic2D) -> float:
1360        """
1361        Compute winding number around origin
1362        """
1363        angles = np.arctan2(geo.y, geo.x)
1364        angles_unwrapped = np.unwrap(angles)
1365        winding = (angles_unwrapped[-1] - angles_unwrapped[0]) / (2 * np.pi)
1366        return winding
1367
1368    @staticmethod
1369    def compute_rotation_numbers(geo: Geodesic2D) -> Tuple[float, float]:
1370        """
1371        Compute rotation numbers (ω_x, ω_y)
1372        """
1373        theta_x = np.arctan2(geo.xi, geo.x)
1374        theta_y = np.arctan2(geo.eta, geo.y)
1375        theta_x = np.unwrap(theta_x)
1376        theta_y = np.unwrap(theta_y)
1377        omega_x = (theta_x[-1] - theta_x[0]) / (geo.t[-1] - geo.t[0])
1378        omega_y = (theta_y[-1] - theta_y[0]) / (geo.t[-1] - geo.t[0])
1379        return omega_x / (2*np.pi), omega_y / (2*np.pi)
1380    
1381    @staticmethod
1382    def detect_kam_tori(periodic_orbits: List[PeriodicOrbit2D],
1383                       tolerance: float = 0.1) -> Dict:
1384        """
1385        Detect KAM tori from periodic orbits
1386        """
1387        if not periodic_orbits:
1388            return {'n_tori': 0, 'tori': []}
1389        actions = np.array([orb.action for orb in periodic_orbits])
1390        # Cluster by action
1391        if len(actions) > 1:
1392            Z = linkage(actions.reshape(-1, 1), method='ward')
1393            clusters = fcluster(Z, t=tolerance, criterion='distance')
1394            n_tori = len(np.unique(clusters))
1395        else:
1396            n_tori = 1
1397            clusters = [1]
1398        # Analyze each torus
1399        tori = []
1400        for torus_id in np.unique(clusters):
1401            orbits_in_torus = [orb for i, orb in enumerate(periodic_orbits) 
1402                              if clusters[i] == torus_id]
1403            mean_action = np.mean([orb.action for orb in orbits_in_torus])
1404            mean_energy = np.mean([orb.energy for orb in orbits_in_torus])
1405            mean_period = np.mean([orb.period for orb in orbits_in_torus])
1406            stabilities = [orb.stability_1 for orb in orbits_in_torus]
1407            is_stable = np.mean(stabilities) < 0
1408            tori.append({
1409                'id': int(torus_id),
1410                'n_orbits': len(orbits_in_torus),
1411                'action': mean_action,
1412                'energy': mean_energy,
1413                'period': mean_period,
1414                'stable': is_stable
1415            })
1416        return {
1417            'n_tori': n_tori,
1418            'tori': tori
1419        }

Additional analysis tools for 2D systems

@staticmethod
def compute_winding_number(geo: src.geometry_2d.Geodesic2D) -> float:
1358    @staticmethod
1359    def compute_winding_number(geo: Geodesic2D) -> float:
1360        """
1361        Compute winding number around origin
1362        """
1363        angles = np.arctan2(geo.y, geo.x)
1364        angles_unwrapped = np.unwrap(angles)
1365        winding = (angles_unwrapped[-1] - angles_unwrapped[0]) / (2 * np.pi)
1366        return winding

Compute winding number around origin

@staticmethod
def compute_rotation_numbers(geo: src.geometry_2d.Geodesic2D) -> Tuple[float, float]:
1368    @staticmethod
1369    def compute_rotation_numbers(geo: Geodesic2D) -> Tuple[float, float]:
1370        """
1371        Compute rotation numbers (ω_x, ω_y)
1372        """
1373        theta_x = np.arctan2(geo.xi, geo.x)
1374        theta_y = np.arctan2(geo.eta, geo.y)
1375        theta_x = np.unwrap(theta_x)
1376        theta_y = np.unwrap(theta_y)
1377        omega_x = (theta_x[-1] - theta_x[0]) / (geo.t[-1] - geo.t[0])
1378        omega_y = (theta_y[-1] - theta_y[0]) / (geo.t[-1] - geo.t[0])
1379        return omega_x / (2*np.pi), omega_y / (2*np.pi)

Compute rotation numbers (ω_x, ω_y)

@staticmethod
def detect_kam_tori( periodic_orbits: List[src.geometry_2d.PeriodicOrbit2D], tolerance: float = 0.1) -> Dict:
1381    @staticmethod
1382    def detect_kam_tori(periodic_orbits: List[PeriodicOrbit2D],
1383                       tolerance: float = 0.1) -> Dict:
1384        """
1385        Detect KAM tori from periodic orbits
1386        """
1387        if not periodic_orbits:
1388            return {'n_tori': 0, 'tori': []}
1389        actions = np.array([orb.action for orb in periodic_orbits])
1390        # Cluster by action
1391        if len(actions) > 1:
1392            Z = linkage(actions.reshape(-1, 1), method='ward')
1393            clusters = fcluster(Z, t=tolerance, criterion='distance')
1394            n_tori = len(np.unique(clusters))
1395        else:
1396            n_tori = 1
1397            clusters = [1]
1398        # Analyze each torus
1399        tori = []
1400        for torus_id in np.unique(clusters):
1401            orbits_in_torus = [orb for i, orb in enumerate(periodic_orbits) 
1402                              if clusters[i] == torus_id]
1403            mean_action = np.mean([orb.action for orb in orbits_in_torus])
1404            mean_energy = np.mean([orb.energy for orb in orbits_in_torus])
1405            mean_period = np.mean([orb.period for orb in orbits_in_torus])
1406            stabilities = [orb.stability_1 for orb in orbits_in_torus]
1407            is_stable = np.mean(stabilities) < 0
1408            tori.append({
1409                'id': int(torus_id),
1410                'n_orbits': len(orbits_in_torus),
1411                'action': mean_action,
1412                'energy': mean_energy,
1413                'period': mean_period,
1414                'stable': is_stable
1415            })
1416        return {
1417            'n_tori': n_tori,
1418            'tori': tori
1419        }

Detect KAM tori from periodic orbits

class Metric1D:
 29class Metric1D:
 30    """
 31    Riemannian metric on a 1D manifold.
 32    
 33    Represents a metric tensor g₁₁(x) and provides methods for computing
 34    geometric quantities: inverse metric, Christoffel symbols, curvature,
 35    and associated operators.
 36    
 37    Parameters
 38    ----------
 39    g_expr : sympy expression
 40        Symbolic expression for the metric component g₁₁(x).
 41    var_x : sympy symbol
 42        Spatial coordinate variable.
 43    
 44    Attributes
 45    ----------
 46    g_expr : sympy expression
 47        Metric tensor component g₁₁(x).
 48    g_inv_expr : sympy expression
 49        Inverse metric g¹¹(x) = 1/g₁₁(x).
 50    sqrt_det_expr : sympy expression
 51        Square root of determinant √|g| = √g₁₁.
 52    christoffel_expr : sympy expression
 53        Christoffel symbol Γ¹₁₁ = ½(log g₁₁)'.
 54    
 55    Examples
 56    --------
 57    >>> # Flat metric
 58    >>> x = symbols('x', real=True)
 59    >>> metric = Metric1D(1, x)
 60    
 61    >>> # Hyperbolic metric
 62    >>> metric = Metric1D(1/x**2, x)
 63    >>> print(metric.gauss_curvature())
 64    
 65    >>> # From Hamiltonian
 66    >>> p = symbols('p', real=True)
 67    >>> H = p**2 / (2*x**2)  # Kinetic term
 68    >>> metric = Metric1D.from_hamiltonian(H, x, p)
 69    """
 70    
 71    def __init__(self, g_expr, var_x):
 72        self.var_x = var_x
 73        self.g_expr = simplify(g_expr)
 74        self.g_inv_expr = simplify(1 / self.g_expr)
 75        self.sqrt_det_expr = simplify(sqrt(abs(self.g_expr)))
 76        
 77        # Christoffel symbol: Γ¹₁₁ = ½(log g₁₁)'
 78        log_g = log(abs(self.g_expr))
 79        self.christoffel_expr = simplify(diff(log_g, var_x) / 2)
 80        
 81        # Lambdify for numerical evaluation
 82        self.g_func = lambdify(var_x, self.g_expr, 'numpy')
 83        self.g_inv_func = lambdify(var_x, self.g_inv_expr, 'numpy')
 84        self.sqrt_det_func = lambdify(var_x, self.sqrt_det_expr, 'numpy')
 85        self.christoffel_func = lambdify(var_x, self.christoffel_expr, 'numpy')
 86    
 87    @classmethod
 88    def from_hamiltonian(cls, H_expr, var_x, var_p):
 89        """
 90        Extract metric from Hamiltonian kinetic term.
 91        
 92        For a Hamiltonian H = g¹¹(x) p²/2 + V(x), extract the inverse
 93        metric g¹¹ = ∂²H/∂p².
 94        
 95        Parameters
 96        ----------
 97        H_expr : sympy expression
 98            Hamiltonian expression H(x, p).
 99        var_x : sympy symbol
100            Position variable.
101        var_p : sympy symbol
102            Momentum variable.
103        
104        Returns
105        -------
106        Metric1D
107            Metric object with g₁₁ = 1/g¹¹.
108        
109        Examples
110        --------
111        >>> x, p = symbols('x p', real=True)
112        >>> H = p**2/(2*x**2) + x**2/2
113        >>> metric = Metric1D.from_hamiltonian(H, x, p)
114        >>> print(metric.g_expr)
115        x**2
116        """
117        # Extract g¹¹ from kinetic term
118        g_inv = diff(H_expr, var_p, 2)
119        g = simplify(1 / g_inv)
120        return cls(g, var_x)
121    
122    def eval(self, x_vals):
123        """
124        Evaluate metric components at given points.
125        
126        Parameters
127        ----------
128        x_vals : float or ndarray
129            Spatial coordinates.
130        
131        Returns
132        -------
133        dict
134            Dictionary containing 'g', 'g_inv', 'sqrt_det', 'christoffel'.
135        """
136        return {
137            'g': self.g_func(x_vals),
138            'g_inv': self.g_inv_func(x_vals),
139            'sqrt_det': self.sqrt_det_func(x_vals),
140            'christoffel': self.christoffel_func(x_vals)
141        }
142    
143    def gauss_curvature(self):
144        """
145        Compute Gaussian curvature K(x).
146        
147        In 1D (curves in higher-dimensional space), intrinsic curvature
148        vanishes. This returns the extrinsic curvature if embedded.
149        For surfaces, use riemannian_2d.
150        
151        Returns
152        -------
153        sympy expression
154            Curvature K(x) = 0 for intrinsic 1D geometry.
155        
156        Notes
157        -----
158        For a curve parametrized by arc length, the curvature measures
159        how much the curve deviates from being a straight line.
160        """
161        # Intrinsic curvature is zero for 1D
162        return sympify(0)
163    
164    def ricci_scalar(self):
165        """
166        Compute Ricci scalar R(x).
167        
168        Returns
169        -------
170        sympy expression
171            Ricci scalar R = 0 (1D manifold).
172        """
173        return sympify(0)
174    
175    def laplace_beltrami_symbol(self):
176        """
177        Compute symbol of the Laplace-Beltrami operator.
178        
179        The Laplace-Beltrami operator in 1D is:
180            Δg f = (1/√g) d/dx(√g g¹¹ df/dx)
181                 = g¹¹ d²f/dx² + (√g)'/√g · g¹¹ df/dx
182        
183        Returns
184        -------
185        dict
186            Dictionary with 'principal' (g¹¹ ξ²) and 'subprincipal' 
187            (first-order transport term).
188        
189        Examples
190        --------
191        >>> x, xi = symbols('x xi', real=True)
192        >>> metric = Metric1D(x**2, x)
193        >>> lb = metric.laplace_beltrami_symbol()
194        >>> print(lb['principal'])
195        xi**2/x**2
196        """
197        x = self.var_x
198        xi = symbols('xi', real=True)
199        
200        # Principal symbol: g¹¹(x) ξ²
201        principal = self.g_inv_expr * xi**2
202        
203        # Subprincipal symbol (transport term)
204        # Coefficient of first derivative: d(log√g)/dx · g¹¹
205        log_sqrt_g = log(self.sqrt_det_expr)
206        transport_coeff = simplify(diff(log_sqrt_g, x) * self.g_inv_expr)
207        subprincipal = transport_coeff * xi
208        
209        return {
210            'principal': simplify(principal),
211            'subprincipal': simplify(subprincipal),
212            'full': simplify(principal + 1j * subprincipal)
213        }
214    
215    def riemannian_volume(self, x_min, x_max, method='symbolic'):
216        """
217        Compute Riemannian volume of interval [x_min, x_max].
218        
219        Vol([a,b]) = ∫ₐᵇ √g₁₁(x) dx
220        
221        Parameters
222        ----------
223        x_min, x_max : float
224            Interval endpoints.
225        method : {'symbolic', 'numerical'}
226            Integration method.
227        
228        Returns
229        -------
230        float or sympy expression
231            Volume of the interval.
232        
233        Examples
234        --------
235        >>> x = symbols('x', real=True)
236        >>> metric = Metric1D(1, x)  # Flat
237        >>> vol = metric.riemannian_volume(0, 1)
238        >>> print(vol)
239        1
240        """
241        if method == 'symbolic':
242            return integrate(self.sqrt_det_expr, (self.var_x, x_min, x_max))
243        elif method == 'numerical':
244            from scipy.integrate import quad
245            integrand = lambda x: self.sqrt_det_func(x)
246            result, error = quad(integrand, x_min, x_max)
247            return result
248        else:
249            raise ValueError("method must be 'symbolic' or 'numerical'")
250    
251    def arc_length(self, x_min, x_max, method='numerical'):
252        """
253        Compute arc length between two points.
254        
255        L = ∫ₐᵇ √g₁₁(x) dx
256        
257        Parameters
258        ----------
259        x_min, x_max : float
260            Endpoints.
261        method : {'symbolic', 'numerical'}
262            Computation method.
263        
264        Returns
265        -------
266        float
267            Arc length.
268        """
269        return self.riemannian_volume(x_min, x_max, method=method)

Riemannian metric on a 1D manifold.

Represents a metric tensor g₁₁(x) and provides methods for computing geometric quantities: inverse metric, Christoffel symbols, curvature, and associated operators.

Parameters

g_expr : sympy expression Symbolic expression for the metric component g₁₁(x). var_x : sympy symbol Spatial coordinate variable.

Attributes

g_expr : sympy expression Metric tensor component g₁₁(x). g_inv_expr : sympy expression Inverse metric g¹¹(x) = 1/g₁₁(x). sqrt_det_expr : sympy expression Square root of determinant √|g| = √g₁₁. christoffel_expr : sympy expression Christoffel symbol Γ¹₁₁ = ½(log g₁₁)'.

Examples

>>> # Flat metric
>>> x = symbols('x', real=True)
>>> metric = Metric1D(1, x)
>>> # Hyperbolic metric
>>> metric = Metric1D(1/x**2, x)
>>> print(metric.gauss_curvature())
>>> # From Hamiltonian
>>> p = symbols('p', real=True)
>>> H = p**2 / (2*x**2)  # Kinetic term
>>> metric = Metric1D.from_hamiltonian(H, x, p)
Metric1D(g_expr, var_x)
71    def __init__(self, g_expr, var_x):
72        self.var_x = var_x
73        self.g_expr = simplify(g_expr)
74        self.g_inv_expr = simplify(1 / self.g_expr)
75        self.sqrt_det_expr = simplify(sqrt(abs(self.g_expr)))
76        
77        # Christoffel symbol: Γ¹₁₁ = ½(log g₁₁)'
78        log_g = log(abs(self.g_expr))
79        self.christoffel_expr = simplify(diff(log_g, var_x) / 2)
80        
81        # Lambdify for numerical evaluation
82        self.g_func = lambdify(var_x, self.g_expr, 'numpy')
83        self.g_inv_func = lambdify(var_x, self.g_inv_expr, 'numpy')
84        self.sqrt_det_func = lambdify(var_x, self.sqrt_det_expr, 'numpy')
85        self.christoffel_func = lambdify(var_x, self.christoffel_expr, 'numpy')
var_x
g_expr
g_inv_expr
sqrt_det_expr
christoffel_expr
g_func
g_inv_func
sqrt_det_func
christoffel_func
@classmethod
def from_hamiltonian(cls, H_expr, var_x, var_p):
 87    @classmethod
 88    def from_hamiltonian(cls, H_expr, var_x, var_p):
 89        """
 90        Extract metric from Hamiltonian kinetic term.
 91        
 92        For a Hamiltonian H = g¹¹(x) p²/2 + V(x), extract the inverse
 93        metric g¹¹ = ∂²H/∂p².
 94        
 95        Parameters
 96        ----------
 97        H_expr : sympy expression
 98            Hamiltonian expression H(x, p).
 99        var_x : sympy symbol
100            Position variable.
101        var_p : sympy symbol
102            Momentum variable.
103        
104        Returns
105        -------
106        Metric1D
107            Metric object with g₁₁ = 1/g¹¹.
108        
109        Examples
110        --------
111        >>> x, p = symbols('x p', real=True)
112        >>> H = p**2/(2*x**2) + x**2/2
113        >>> metric = Metric1D.from_hamiltonian(H, x, p)
114        >>> print(metric.g_expr)
115        x**2
116        """
117        # Extract g¹¹ from kinetic term
118        g_inv = diff(H_expr, var_p, 2)
119        g = simplify(1 / g_inv)
120        return cls(g, var_x)

Extract metric from Hamiltonian kinetic term.

For a Hamiltonian H = g¹¹(x) p²/2 + V(x), extract the inverse metric g¹¹ = ∂²H/∂p².

Parameters

H_expr : sympy expression Hamiltonian expression H(x, p). var_x : sympy symbol Position variable. var_p : sympy symbol Momentum variable.

Returns

Metric1D Metric object with g₁₁ = 1/g¹¹.

Examples

>>> x, p = symbols('x p', real=True)
>>> H = p**2/(2*x**2) + x**2/2
>>> metric = Metric1D.from_hamiltonian(H, x, p)
>>> print(metric.g_expr)
x**2
def eval(self, x_vals):
122    def eval(self, x_vals):
123        """
124        Evaluate metric components at given points.
125        
126        Parameters
127        ----------
128        x_vals : float or ndarray
129            Spatial coordinates.
130        
131        Returns
132        -------
133        dict
134            Dictionary containing 'g', 'g_inv', 'sqrt_det', 'christoffel'.
135        """
136        return {
137            'g': self.g_func(x_vals),
138            'g_inv': self.g_inv_func(x_vals),
139            'sqrt_det': self.sqrt_det_func(x_vals),
140            'christoffel': self.christoffel_func(x_vals)
141        }

Evaluate metric components at given points.

Parameters

x_vals : float or ndarray Spatial coordinates.

Returns

dict Dictionary containing 'g', 'g_inv', 'sqrt_det', 'christoffel'.

def gauss_curvature(self):
143    def gauss_curvature(self):
144        """
145        Compute Gaussian curvature K(x).
146        
147        In 1D (curves in higher-dimensional space), intrinsic curvature
148        vanishes. This returns the extrinsic curvature if embedded.
149        For surfaces, use riemannian_2d.
150        
151        Returns
152        -------
153        sympy expression
154            Curvature K(x) = 0 for intrinsic 1D geometry.
155        
156        Notes
157        -----
158        For a curve parametrized by arc length, the curvature measures
159        how much the curve deviates from being a straight line.
160        """
161        # Intrinsic curvature is zero for 1D
162        return sympify(0)

Compute Gaussian curvature K(x).

In 1D (curves in higher-dimensional space), intrinsic curvature vanishes. This returns the extrinsic curvature if embedded. For surfaces, use riemannian_2d.

Returns

sympy expression Curvature K(x) = 0 for intrinsic 1D geometry.

Notes

For a curve parametrized by arc length, the curvature measures how much the curve deviates from being a straight line.

def ricci_scalar(self):
164    def ricci_scalar(self):
165        """
166        Compute Ricci scalar R(x).
167        
168        Returns
169        -------
170        sympy expression
171            Ricci scalar R = 0 (1D manifold).
172        """
173        return sympify(0)

Compute Ricci scalar R(x).

Returns

sympy expression Ricci scalar R = 0 (1D manifold).

def laplace_beltrami_symbol(self):
175    def laplace_beltrami_symbol(self):
176        """
177        Compute symbol of the Laplace-Beltrami operator.
178        
179        The Laplace-Beltrami operator in 1D is:
180            Δg f = (1/√g) d/dx(√g g¹¹ df/dx)
181                 = g¹¹ d²f/dx² + (√g)'/√g · g¹¹ df/dx
182        
183        Returns
184        -------
185        dict
186            Dictionary with 'principal' (g¹¹ ξ²) and 'subprincipal' 
187            (first-order transport term).
188        
189        Examples
190        --------
191        >>> x, xi = symbols('x xi', real=True)
192        >>> metric = Metric1D(x**2, x)
193        >>> lb = metric.laplace_beltrami_symbol()
194        >>> print(lb['principal'])
195        xi**2/x**2
196        """
197        x = self.var_x
198        xi = symbols('xi', real=True)
199        
200        # Principal symbol: g¹¹(x) ξ²
201        principal = self.g_inv_expr * xi**2
202        
203        # Subprincipal symbol (transport term)
204        # Coefficient of first derivative: d(log√g)/dx · g¹¹
205        log_sqrt_g = log(self.sqrt_det_expr)
206        transport_coeff = simplify(diff(log_sqrt_g, x) * self.g_inv_expr)
207        subprincipal = transport_coeff * xi
208        
209        return {
210            'principal': simplify(principal),
211            'subprincipal': simplify(subprincipal),
212            'full': simplify(principal + 1j * subprincipal)
213        }

Compute symbol of the Laplace-Beltrami operator.

The Laplace-Beltrami operator in 1D is: Δg f = (1/√g) d/dx(√g g¹¹ df/dx) = g¹¹ d²f/dx² + (√g)'/√g · g¹¹ df/dx

Returns

dict Dictionary with 'principal' (g¹¹ ξ²) and 'subprincipal' (first-order transport term).

Examples

>>> x, xi = symbols('x xi', real=True)
>>> metric = Metric1D(x**2, x)
>>> lb = metric.laplace_beltrami_symbol()
>>> print(lb['principal'])
xi**2/x**2
def riemannian_volume(self, x_min, x_max, method='symbolic'):
215    def riemannian_volume(self, x_min, x_max, method='symbolic'):
216        """
217        Compute Riemannian volume of interval [x_min, x_max].
218        
219        Vol([a,b]) = ∫ₐᵇ √g₁₁(x) dx
220        
221        Parameters
222        ----------
223        x_min, x_max : float
224            Interval endpoints.
225        method : {'symbolic', 'numerical'}
226            Integration method.
227        
228        Returns
229        -------
230        float or sympy expression
231            Volume of the interval.
232        
233        Examples
234        --------
235        >>> x = symbols('x', real=True)
236        >>> metric = Metric1D(1, x)  # Flat
237        >>> vol = metric.riemannian_volume(0, 1)
238        >>> print(vol)
239        1
240        """
241        if method == 'symbolic':
242            return integrate(self.sqrt_det_expr, (self.var_x, x_min, x_max))
243        elif method == 'numerical':
244            from scipy.integrate import quad
245            integrand = lambda x: self.sqrt_det_func(x)
246            result, error = quad(integrand, x_min, x_max)
247            return result
248        else:
249            raise ValueError("method must be 'symbolic' or 'numerical'")

Compute Riemannian volume of interval [x_min, x_max].

Vol([a,b]) = ∫ₐᵇ √g₁₁(x) dx

Parameters

x_min, x_max : float Interval endpoints. method : {'symbolic', 'numerical'} Integration method.

Returns

float or sympy expression Volume of the interval.

Examples

>>> x = symbols('x', real=True)
>>> metric = Metric1D(1, x)  # Flat
>>> vol = metric.riemannian_volume(0, 1)
>>> print(vol)
1
def arc_length(self, x_min, x_max, method='numerical'):
251    def arc_length(self, x_min, x_max, method='numerical'):
252        """
253        Compute arc length between two points.
254        
255        L = ∫ₐᵇ √g₁₁(x) dx
256        
257        Parameters
258        ----------
259        x_min, x_max : float
260            Endpoints.
261        method : {'symbolic', 'numerical'}
262            Computation method.
263        
264        Returns
265        -------
266        float
267            Arc length.
268        """
269        return self.riemannian_volume(x_min, x_max, method=method)

Compute arc length between two points.

L = ∫ₐᵇ √g₁₁(x) dx

Parameters

x_min, x_max : float Endpoints. method : {'symbolic', 'numerical'} Computation method.

Returns

float Arc length.

def geodesic_integrator(metric, x0, v0, tspan, method='rk4', n_steps=1000):
296def geodesic_integrator(metric, x0, v0, tspan, method='rk4', n_steps=1000):
297    """
298    Integrate geodesic equations.
299    
300    Solves: ẍ + Γ¹₁₁(x) ẋ² = 0
301    
302    Converted to first-order system:
303        ẋ = v
304        v̇ = -Γ¹₁₁(x) v²
305    
306    Parameters
307    ----------
308    metric : Metric1D
309        Riemannian metric.
310    x0 : float
311        Initial position.
312    v0 : float
313        Initial velocity dx/dt.
314    tspan : tuple
315        Time interval (t_start, t_end).
316    method : {'rk4', 'symplectic', 'adaptive'}
317        Integration method.
318    n_steps : int
319        Number of time steps.
320    
321    Returns
322    -------
323    dict
324        Dictionary with 't', 'x', 'v' arrays.
325    
326    Examples
327    --------
328    >>> x = symbols('x', real=True)
329    >>> metric = Metric1D(1, x)  # Flat
330    >>> traj = geodesic_integrator(metric, 0.0, 1.0, (0, 10))
331    >>> plt.plot(traj['t'], traj['x'])
332    
333    Notes
334    -----
335    - For flat metric, geodesics are straight lines.
336    - Symplectic integrators preserve energy better for long integrations.
337    """
338    from scipy.integrate import solve_ivp
339    
340    Gamma_func = metric.christoffel_func
341    
342    def geodesic_ode(t, y):
343        x, v = y
344        dxdt = v
345        dvdt = -Gamma_func(x) * v**2
346        return [dxdt, dvdt]
347    
348    if method == 'rk4' or method == 'adaptive':
349        sol = solve_ivp(
350            geodesic_ode, 
351            tspan, 
352            [x0, v0],
353            method='RK45' if method == 'adaptive' else 'RK23',
354            t_eval=np.linspace(tspan[0], tspan[1], n_steps)
355        )
356        return {
357            't': sol.t,
358            'x': sol.y[0],
359            'v': sol.y[1]
360        }
361    
362    elif method == 'symplectic':
363        # Symplectic Euler for Hamiltonian formulation
364        # H = g¹¹(x)/2 · p²
365        # ẋ = g¹¹ p
366        # ṗ = -½ (∂g¹¹/∂x) p²
367        
368        dt = (tspan[1] - tspan[0]) / n_steps
369        t_vals = np.linspace(tspan[0], tspan[1], n_steps)
370        x_vals = np.zeros(n_steps)
371        p_vals = np.zeros(n_steps)
372        
373        # Initial momentum: p = v / g¹¹
374        g_inv_0 = metric.g_inv_func(x0)
375        p0 = v0 / g_inv_0
376        
377        x_vals[0] = x0
378        p_vals[0] = p0
379        
380        # Prepare derivative of g¹¹
381        g_inv_prime = lambdify(
382            metric.var_x,
383            diff(metric.g_inv_expr, metric.var_x),
384            'numpy'
385        )
386        
387        for i in range(n_steps - 1):
388            x = x_vals[i]
389            p = p_vals[i]
390            
391            # Symplectic Euler step
392            g_inv = metric.g_inv_func(x)
393            p_new = p - dt * 0.5 * g_inv_prime(x) * p**2
394            x_new = x + dt * g_inv * p_new
395            
396            x_vals[i+1] = x_new
397            p_vals[i+1] = p_new
398        
399        # Convert momentum back to velocity
400        v_vals = np.array([
401            metric.g_inv_func(x) * p 
402            for x, p in zip(x_vals, p_vals)
403        ])
404        
405        return {
406            't': t_vals,
407            'x': x_vals,
408            'v': v_vals,
409            'p': p_vals
410        }
411    
412    else:
413        raise ValueError("method must be 'rk4', 'symplectic', or 'adaptive'")

Integrate geodesic equations.

Solves: ẍ + Γ¹₁₁(x) ẋ² = 0

Converted to first-order system: ẋ = v v̇ = -Γ¹₁₁(x) v²

Parameters

metric : Metric1D Riemannian metric. x0 : float Initial position. v0 : float Initial velocity dx/dt. tspan : tuple Time interval (t_start, t_end). method : {'rk4', 'symplectic', 'adaptive'} Integration method. n_steps : int Number of time steps.

Returns

dict Dictionary with 't', 'x', 'v' arrays.

Examples

>>> x = symbols('x', real=True)
>>> metric = Metric1D(1, x)  # Flat
>>> traj = geodesic_integrator(metric, 0.0, 1.0, (0, 10))
>>> plt.plot(traj['t'], traj['x'])

Notes

  • For flat metric, geodesics are straight lines.
  • Symplectic integrators preserve energy better for long integrations.
def laplace_beltrami(metric):
518def laplace_beltrami(metric):
519    """
520    Construct Laplace-Beltrami operator as a pseudo-differential operator.
521    
522    Returns a symbol compatible with psiop.PseudoDifferentialOperator.
523    
524    Parameters
525    ----------
526    metric : Metric1D
527        Riemannian metric.
528    
529    Returns
530    -------
531    dict
532        Symbol components for use with PseudoDifferentialOperator.
533    
534    Examples
535    --------
536    >>> from psiop import PseudoDifferentialOperator
537    >>> x = symbols('x', real=True)
538    >>> metric = Metric1D(x**2, x)
539    >>> lb_symbol = laplace_beltrami(metric)
540    >>> op = PseudoDifferentialOperator(
541    ...     lb_symbol['full'], [x], mode='symbol'
542    ... )
543    """
544    return metric.laplace_beltrami_symbol()

Construct Laplace-Beltrami operator as a pseudo-differential operator.

Returns a symbol compatible with psiop.PseudoDifferentialOperator.

Parameters

metric : Metric1D Riemannian metric.

Returns

dict Symbol components for use with PseudoDifferentialOperator.

Examples

>>> from psiop import PseudoDifferentialOperator
>>> x = symbols('x', real=True)
>>> metric = Metric1D(x**2, x)
>>> lb_symbol = laplace_beltrami(metric)
>>> op = PseudoDifferentialOperator(
...     lb_symbol['full'], [x], mode='symbol'
... )
class Metric2D:
 30class Metric2D:
 31    """
 32    Riemannian metric tensor on a 2D manifold.
 33    
 34    Represents a metric tensor as a 2×2 matrix:
 35        g = [[g₁₁, g₁₂],
 36             [g₁₂, g₂₂]]
 37    
 38    Parameters
 39    ----------
 40    g_matrix : 2×2 sympy Matrix or list
 41        Metric tensor components [[g₁₁, g₁₂], [g₁₂, g₂₂]].
 42    vars_xy : tuple of sympy symbols
 43        Coordinate variables (x, y).
 44    
 45    Attributes
 46    ----------
 47    g_matrix : sympy Matrix
 48        Metric tensor gᵢⱼ.
 49    g_inv_matrix : sympy Matrix
 50        Inverse metric g^ij.
 51    det_g : sympy expression
 52        Determinant |g|.
 53    sqrt_det_g : sympy expression
 54        √|g| for volume forms.
 55    christoffel : dict
 56        Christoffel symbols Γⁱⱼₖ.
 57    
 58    Examples
 59    --------
 60    >>> # Euclidean metric
 61    >>> x, y = symbols('x y', real=True)
 62    >>> g = Matrix([[1, 0], [0, 1]])
 63    >>> metric = Metric2D(g, (x, y))
 64    
 65    >>> # Polar coordinates
 66    >>> r, theta = symbols('r theta', real=True, positive=True)
 67    >>> g_polar = Matrix([[1, 0], [0, r**2]])
 68    >>> metric = Metric2D(g_polar, (r, theta))
 69    >>> print(metric.gauss_curvature())
 70    
 71    >>> # From Hamiltonian
 72    >>> p_x, p_y = symbols('p_x p_y', real=True)
 73    >>> H = (p_x**2 + p_y**2)/(2*x**2)
 74    >>> metric = Metric2D.from_hamiltonian(H, (x,y), (p_x,p_y))
 75    """
 76    
 77    def __init__(self, g_matrix, vars_xy):
 78        if not isinstance(g_matrix, Matrix):
 79            g_matrix = Matrix(g_matrix)
 80    
 81        if g_matrix.shape != (2, 2):
 82            raise ValueError("Metric2D requires a 2×2 metric tensor")
 83    
 84        if len(vars_xy) != 2:
 85            raise ValueError("Metric2D requires exactly two coordinates (x, y)")
 86        self.vars_xy = vars_xy
 87        self.x, self.y = vars_xy
 88        
 89        if not isinstance(g_matrix, Matrix):
 90            g_matrix = Matrix(g_matrix)
 91        
 92        self.g_matrix = simplify(g_matrix)
 93        self.det_g = simplify(self.g_matrix.det())
 94        self.sqrt_det_g = simplify(sqrt(abs(self.det_g)))
 95        self.g_inv_matrix = simplify(self.g_matrix.inv())
 96        
 97        # Compute Christoffel symbols
 98        self.christoffel = self._compute_christoffel()
 99        
100        # Lambdify for numerical evaluation
101        self._lambdify_all()
102    
103    def _compute_christoffel(self):
104        """
105        Compute all Christoffel symbols Γⁱⱼₖ.
106        
107        Γⁱⱼₖ = ½ g^iℓ (∂ⱼgₖℓ + ∂ₖgⱼℓ - ∂ℓgⱼₖ)
108        
109        Returns
110        -------
111        dict
112            Nested dict: christoffel[i][j][k] = Γⁱⱼₖ
113        """
114        x, y = self.vars_xy
115        g = self.g_matrix
116        g_inv = self.g_inv_matrix
117        
118        Gamma = {}
119        for i in range(2):
120            Gamma[i] = {}
121            for j in range(2):
122                Gamma[i][j] = {}
123                for k in range(2):
124                    expr = 0
125                    for ell in range(2):
126                        term1 = diff(g[k, ell], [x, y][j])
127                        term2 = diff(g[j, ell], [x, y][k])
128                        term3 = diff(g[j, k], [x, y][ell])
129                        expr += g_inv[i, ell] * (term1 + term2 - term3) / 2
130                    Gamma[i][j][k] = simplify(expr)
131        
132        return Gamma
133    
134    def _lambdify_all(self):
135        """Prepare numerical functions for all geometric quantities."""
136        x, y = self.vars_xy
137        
138        # Metric components
139        self.g_func = {
140            (i, j): lambdify((x, y), self.g_matrix[i, j], 'numpy')
141            for i in range(2) for j in range(2)
142        }
143        
144        self.g_inv_func = {
145            (i, j): lambdify((x, y), self.g_inv_matrix[i, j], 'numpy')
146            for i in range(2) for j in range(2)
147        }
148        
149        self.det_g_func = lambdify((x, y), self.det_g, 'numpy')
150        self.sqrt_det_g_func = lambdify((x, y), self.sqrt_det_g, 'numpy')
151        
152        # Christoffel symbols
153        self.christoffel_func = {}
154        for i in range(2):
155            self.christoffel_func[i] = {}
156            for j in range(2):
157                self.christoffel_func[i][j] = {}
158                for k in range(2):
159                    self.christoffel_func[i][j][k] = lambdify(
160                        (x, y), self.christoffel[i][j][k], 'numpy'
161                    )
162    
163    @classmethod
164    def from_hamiltonian(cls, H_expr, vars_xy, vars_p):
165        """
166        Extract metric from Hamiltonian kinetic term.
167        
168        For H = ½ g^ij pᵢ pⱼ + V, extract inverse metric from Hessian:
169            g^ij = ∂²H/∂pᵢ∂pⱼ
170        
171        Parameters
172        ----------
173        H_expr : sympy expression
174            Hamiltonian H(x, y, pₓ, pᵧ).
175        vars_xy : tuple
176            Position variables (x, y).
177        vars_p : tuple
178            Momentum variables (pₓ, pᵧ).
179        
180        Returns
181        -------
182        Metric2D
183            Metric with gᵢⱼ = (g^ij)⁻¹.
184        
185        Examples
186        --------
187        >>> x, y, px, py = symbols('x y p_x p_y', real=True)
188        >>> H = (px**2 + py**2)/(2*x**2)
189        >>> metric = Metric2D.from_hamiltonian(H, (x,y), (px,py))
190        """
191        px, py = vars_p
192        
193        # Compute Hessian
194        g_inv_11 = diff(H_expr, px, 2)
195        g_inv_12 = diff(H_expr, px, py)
196        g_inv_22 = diff(H_expr, py, 2)
197        
198        g_inv = Matrix([[g_inv_11, g_inv_12],
199                        [g_inv_12, g_inv_22]])
200        
201        g = simplify(g_inv.inv())
202        return cls(g, vars_xy)
203    
204    def eval(self, x_vals, y_vals):
205        """
206        Evaluate metric components at given points.
207        
208        Parameters
209        ----------
210        x_vals, y_vals : float or ndarray
211            Coordinate values.
212        
213        Returns
214        -------
215        dict
216            Dictionary containing metric tensors and geometric quantities.
217        """
218        result = {
219            'g': np.zeros((2, 2, *np.shape(x_vals))),
220            'g_inv': np.zeros((2, 2, *np.shape(x_vals))),
221            'det_g': self.det_g_func(x_vals, y_vals),
222            'sqrt_det_g': self.sqrt_det_g_func(x_vals, y_vals),
223            'christoffel': {}
224        }
225        
226        for i in range(2):
227            for j in range(2):
228                result['g'][i, j] = self.g_func[(i, j)](x_vals, y_vals)
229                result['g_inv'][i, j] = self.g_inv_func[(i, j)](x_vals, y_vals)
230        
231        for i in range(2):
232            result['christoffel'][i] = {}
233            for j in range(2):
234                result['christoffel'][i][j] = {}
235                for k in range(2):
236                    result['christoffel'][i][j][k] = \
237                        self.christoffel_func[i][j][k](x_vals, y_vals)
238        
239        return result
240    
241    def gauss_curvature(self):
242        """
243        Compute Gaussian curvature K.
244        
245        For a 2D Riemannian manifold, the Gaussian curvature is:
246            K = R₁₂₁₂ / |g|
247        
248        where R₁₂₁₂ is a component of the Riemann curvature tensor.
249        
250        Returns
251        -------
252        sympy expression
253            Gaussian curvature K(x, y).
254        
255        Notes
256        -----
257        By Gauss-Bonnet theorem: ∫∫_M K dA = 2π χ(M)
258        where χ is the Euler characteristic.
259        
260        Examples
261        --------
262        >>> x, y = symbols('x y', real=True)
263        >>> g = Matrix([[1, 0], [0, 1]])
264        >>> metric = Metric2D(g, (x, y))
265        >>> print(metric.gauss_curvature())
266        0
267        """
268        # Ensure we have the full Riemann tensor
269        # R^i_{jkl}
270        R = self.riemann_tensor()
271        g = self.g_matrix
272
273        # Calculate the covariant component R_xyxy (or R_1212)
274        # Indices: x=0, y=1
275        # R_xyxy = g_xx * R^x_yxy + g_xy * R^y_yxy
276        # R^i_{jkl} with j=1 (y), k=0 (x), l=1 (y)
277
278        R_x_yxy = R[0][1][0][1]  # R^0_{101}
279        R_y_yxy = R[1][1][0][1]  # R^1_{101}
280
281        # Lowering index: R_{0101} = g_{0m} R^m_{101}
282        R_xyxy = g[0,0] * R_x_yxy + g[0,1] * R_y_yxy
283
284        # K = R_1212 / det(g)
285        K = simplify(R_xyxy / self.det_g)
286        
287        return K
288    
289    def riemann_tensor(self):
290        """
291        Compute Riemann curvature tensor Rⁱⱼₖₗ.
292        
293        Returns
294        -------
295        dict
296            Nested dict with all non-zero components.
297        
298        Notes
299        -----
300        In 2D, only one independent component exists (up to symmetries).
301        """
302        x, y = self.vars_xy
303        Gamma = self.christoffel
304        
305        R = {}
306        for i in range(2):
307            R[i] = {}
308            for j in range(2):
309                R[i][j] = {}
310                for k in range(2):
311                    R[i][j][k] = {}
312                    for ell in range(2):
313                        expr = diff(Gamma[i][j][ell], [x, y][k])
314                        expr -= diff(Gamma[i][j][k], [x, y][ell])
315                        
316                        for m in range(2):
317                            expr += Gamma[i][m][k] * Gamma[m][j][ell]
318                            expr -= Gamma[i][m][ell] * Gamma[m][j][k]
319                        
320                        R[i][j][k][ell] = simplify(expr)
321        
322        return R
323    
324    def ricci_tensor(self):
325        """
326        Compute Ricci curvature tensor Rᵢⱼ.
327        
328        Rᵢⱼ = Rᵏᵢₖⱼ (contraction of Riemann tensor)
329        
330        Returns
331        -------
332        sympy Matrix
333            2×2 Ricci tensor.
334        """
335        R_full = self.riemann_tensor()
336        
337        Ric = zeros(2, 2)
338        for i in range(2):
339            for j in range(2):
340                for k in range(2):
341                    Ric[i, j] += R_full[k][i][k][j]
342        
343        return simplify(Ric)
344    
345    def ricci_scalar(self):
346        """
347        Compute scalar curvature R.
348        
349        R = g^ij Rᵢⱼ
350        
351        For 2D surfaces: R = 2K (twice the Gaussian curvature).
352        
353        Returns
354        -------
355        sympy expression
356            Scalar curvature R(x, y).
357        """
358        Ric = self.ricci_tensor()
359        g_inv = self.g_inv_matrix
360        
361        R = 0
362        for i in range(2):
363            for j in range(2):
364                R += g_inv[i, j] * Ric[i, j]
365        
366        return simplify(R)
367    
368    def laplace_beltrami_symbol(self):
369        """
370        Compute symbol of Laplace-Beltrami operator.
371        
372        Principal symbol: g^ij ξᵢ ξⱼ
373        Subprincipal: transport terms from √|g| factor
374        
375        Returns
376        -------
377        dict
378            Symbol components: 'principal', 'subprincipal', 'full'.
379        
380        Examples
381        --------
382        >>> x, y, xi, eta = symbols('x y xi eta', real=True)
383        >>> g = Matrix([[1, 0], [0, 1]])
384        >>> metric = Metric2D(g, (x, y))
385        >>> symbol = metric.laplace_beltrami_symbol()
386        >>> print(symbol['principal'])
387        xi**2 + eta**2
388        """
389        x, y = self.vars_xy
390        xi, eta = symbols('xi eta', real=True)
391        
392        g_inv = self.g_inv_matrix
393        
394        # Principal symbol
395        principal = (g_inv[0,0] * xi**2 + 
396                    2 * g_inv[0,1] * xi * eta +
397                    g_inv[1,1] * eta**2)
398        
399        # Subprincipal (from divergence structure)
400        # ∇·(√g g^ij ∇u) = √g g^ij ∂ᵢ∂ⱼu + (∂ᵢ√g g^ij) ∂ⱼu
401        sqrt_g = self.sqrt_det_g
402        
403        coeff_x = diff(sqrt_g * g_inv[0,0], x) + diff(sqrt_g * g_inv[0,1], y)
404        coeff_y = diff(sqrt_g * g_inv[1,0], x) + diff(sqrt_g * g_inv[1,1], y)
405        
406        subprincipal = simplify((coeff_x * xi + coeff_y * eta) / sqrt_g)
407        
408        return {
409            'principal': simplify(principal),
410            'subprincipal': simplify(subprincipal),
411            'full': simplify(principal + 1j * subprincipal)
412        }
413    
414    def riemannian_volume(self, domain, method='numerical'):
415        """
416        Compute Riemannian volume of a domain.
417        
418        Vol(Ω) = ∫∫_Ω √|g| dx dy
419        
420        Parameters
421        ----------
422        domain : tuple
423            For rectangular: ((x_min, x_max), (y_min, y_max)).
424            For custom: callable that defines integration region.
425        method : {'numerical', 'symbolic'}
426            Integration method.
427        
428        Returns
429        -------
430        float or sympy expression
431            Volume of the domain.
432        """
433        x, y = self.vars_xy
434        sqrt_g = self.sqrt_det_g
435        
436        if method == 'symbolic':
437            (x_min, x_max), (y_min, y_max) = domain
438            return integrate(sqrt_g, (x, x_min, x_max), (y, y_min, y_max))
439        
440        elif method == 'numerical':
441            from scipy.integrate import dblquad
442            (x_min, x_max), (y_min, y_max) = domain
443            
444            integrand = lambda y, x: self.sqrt_det_g_func(x, y)
445            result, error = dblquad(integrand, x_min, x_max, y_min, y_max)
446            return result
447        
448        else:
449            raise ValueError("method must be 'symbolic' or 'numerical'")

Riemannian metric tensor on a 2D manifold.

Represents a metric tensor as a 2×2 matrix: g = [[g₁₁, g₁₂], [g₁₂, g₂₂]]

Parameters

g_matrix : 2×2 sympy Matrix or list Metric tensor components [[g₁₁, g₁₂], [g₁₂, g₂₂]]. vars_xy : tuple of sympy symbols Coordinate variables (x, y).

Attributes

g_matrix : sympy Matrix Metric tensor gᵢⱼ. g_inv_matrix : sympy Matrix Inverse metric g^ij. det_g : sympy expression Determinant |g|. sqrt_det_g : sympy expression √|g| for volume forms. christoffel : dict Christoffel symbols Γⁱⱼₖ.

Examples

>>> # Euclidean metric
>>> x, y = symbols('x y', real=True)
>>> g = Matrix([[1, 0], [0, 1]])
>>> metric = Metric2D(g, (x, y))
>>> # Polar coordinates
>>> r, theta = symbols('r theta', real=True, positive=True)
>>> g_polar = Matrix([[1, 0], [0, r**2]])
>>> metric = Metric2D(g_polar, (r, theta))
>>> print(metric.gauss_curvature())
>>> # From Hamiltonian
>>> p_x, p_y = symbols('p_x p_y', real=True)
>>> H = (p_x**2 + p_y**2)/(2*x**2)
>>> metric = Metric2D.from_hamiltonian(H, (x,y), (p_x,p_y))
Metric2D(g_matrix, vars_xy)
 77    def __init__(self, g_matrix, vars_xy):
 78        if not isinstance(g_matrix, Matrix):
 79            g_matrix = Matrix(g_matrix)
 80    
 81        if g_matrix.shape != (2, 2):
 82            raise ValueError("Metric2D requires a 2×2 metric tensor")
 83    
 84        if len(vars_xy) != 2:
 85            raise ValueError("Metric2D requires exactly two coordinates (x, y)")
 86        self.vars_xy = vars_xy
 87        self.x, self.y = vars_xy
 88        
 89        if not isinstance(g_matrix, Matrix):
 90            g_matrix = Matrix(g_matrix)
 91        
 92        self.g_matrix = simplify(g_matrix)
 93        self.det_g = simplify(self.g_matrix.det())
 94        self.sqrt_det_g = simplify(sqrt(abs(self.det_g)))
 95        self.g_inv_matrix = simplify(self.g_matrix.inv())
 96        
 97        # Compute Christoffel symbols
 98        self.christoffel = self._compute_christoffel()
 99        
100        # Lambdify for numerical evaluation
101        self._lambdify_all()
vars_xy
g_matrix
det_g
sqrt_det_g
g_inv_matrix
christoffel
@classmethod
def from_hamiltonian(cls, H_expr, vars_xy, vars_p):
163    @classmethod
164    def from_hamiltonian(cls, H_expr, vars_xy, vars_p):
165        """
166        Extract metric from Hamiltonian kinetic term.
167        
168        For H = ½ g^ij pᵢ pⱼ + V, extract inverse metric from Hessian:
169            g^ij = ∂²H/∂pᵢ∂pⱼ
170        
171        Parameters
172        ----------
173        H_expr : sympy expression
174            Hamiltonian H(x, y, pₓ, pᵧ).
175        vars_xy : tuple
176            Position variables (x, y).
177        vars_p : tuple
178            Momentum variables (pₓ, pᵧ).
179        
180        Returns
181        -------
182        Metric2D
183            Metric with gᵢⱼ = (g^ij)⁻¹.
184        
185        Examples
186        --------
187        >>> x, y, px, py = symbols('x y p_x p_y', real=True)
188        >>> H = (px**2 + py**2)/(2*x**2)
189        >>> metric = Metric2D.from_hamiltonian(H, (x,y), (px,py))
190        """
191        px, py = vars_p
192        
193        # Compute Hessian
194        g_inv_11 = diff(H_expr, px, 2)
195        g_inv_12 = diff(H_expr, px, py)
196        g_inv_22 = diff(H_expr, py, 2)
197        
198        g_inv = Matrix([[g_inv_11, g_inv_12],
199                        [g_inv_12, g_inv_22]])
200        
201        g = simplify(g_inv.inv())
202        return cls(g, vars_xy)

Extract metric from Hamiltonian kinetic term.

For H = ½ g^ij pᵢ pⱼ + V, extract inverse metric from Hessian: g^ij = ∂²H/∂pᵢ∂pⱼ

Parameters

H_expr : sympy expression Hamiltonian H(x, y, pₓ, pᵧ). vars_xy : tuple Position variables (x, y). vars_p : tuple Momentum variables (pₓ, pᵧ).

Returns

Metric2D Metric with gᵢⱼ = (g^ij)⁻¹.

Examples

>>> x, y, px, py = symbols('x y p_x p_y', real=True)
>>> H = (px**2 + py**2)/(2*x**2)
>>> metric = Metric2D.from_hamiltonian(H, (x,y), (px,py))
def eval(self, x_vals, y_vals):
204    def eval(self, x_vals, y_vals):
205        """
206        Evaluate metric components at given points.
207        
208        Parameters
209        ----------
210        x_vals, y_vals : float or ndarray
211            Coordinate values.
212        
213        Returns
214        -------
215        dict
216            Dictionary containing metric tensors and geometric quantities.
217        """
218        result = {
219            'g': np.zeros((2, 2, *np.shape(x_vals))),
220            'g_inv': np.zeros((2, 2, *np.shape(x_vals))),
221            'det_g': self.det_g_func(x_vals, y_vals),
222            'sqrt_det_g': self.sqrt_det_g_func(x_vals, y_vals),
223            'christoffel': {}
224        }
225        
226        for i in range(2):
227            for j in range(2):
228                result['g'][i, j] = self.g_func[(i, j)](x_vals, y_vals)
229                result['g_inv'][i, j] = self.g_inv_func[(i, j)](x_vals, y_vals)
230        
231        for i in range(2):
232            result['christoffel'][i] = {}
233            for j in range(2):
234                result['christoffel'][i][j] = {}
235                for k in range(2):
236                    result['christoffel'][i][j][k] = \
237                        self.christoffel_func[i][j][k](x_vals, y_vals)
238        
239        return result

Evaluate metric components at given points.

Parameters

x_vals, y_vals : float or ndarray Coordinate values.

Returns

dict Dictionary containing metric tensors and geometric quantities.

def gauss_curvature(self):
241    def gauss_curvature(self):
242        """
243        Compute Gaussian curvature K.
244        
245        For a 2D Riemannian manifold, the Gaussian curvature is:
246            K = R₁₂₁₂ / |g|
247        
248        where R₁₂₁₂ is a component of the Riemann curvature tensor.
249        
250        Returns
251        -------
252        sympy expression
253            Gaussian curvature K(x, y).
254        
255        Notes
256        -----
257        By Gauss-Bonnet theorem: ∫∫_M K dA = 2π χ(M)
258        where χ is the Euler characteristic.
259        
260        Examples
261        --------
262        >>> x, y = symbols('x y', real=True)
263        >>> g = Matrix([[1, 0], [0, 1]])
264        >>> metric = Metric2D(g, (x, y))
265        >>> print(metric.gauss_curvature())
266        0
267        """
268        # Ensure we have the full Riemann tensor
269        # R^i_{jkl}
270        R = self.riemann_tensor()
271        g = self.g_matrix
272
273        # Calculate the covariant component R_xyxy (or R_1212)
274        # Indices: x=0, y=1
275        # R_xyxy = g_xx * R^x_yxy + g_xy * R^y_yxy
276        # R^i_{jkl} with j=1 (y), k=0 (x), l=1 (y)
277
278        R_x_yxy = R[0][1][0][1]  # R^0_{101}
279        R_y_yxy = R[1][1][0][1]  # R^1_{101}
280
281        # Lowering index: R_{0101} = g_{0m} R^m_{101}
282        R_xyxy = g[0,0] * R_x_yxy + g[0,1] * R_y_yxy
283
284        # K = R_1212 / det(g)
285        K = simplify(R_xyxy / self.det_g)
286        
287        return K

Compute Gaussian curvature K.

For a 2D Riemannian manifold, the Gaussian curvature is: K = R₁₂₁₂ / |g|

where R₁₂₁₂ is a component of the Riemann curvature tensor.

Returns

sympy expression Gaussian curvature K(x, y).

Notes

By Gauss-Bonnet theorem: ∫∫_M K dA = 2π χ(M) where χ is the Euler characteristic.

Examples

>>> x, y = symbols('x y', real=True)
>>> g = Matrix([[1, 0], [0, 1]])
>>> metric = Metric2D(g, (x, y))
>>> print(metric.gauss_curvature())
0
def riemann_tensor(self):
289    def riemann_tensor(self):
290        """
291        Compute Riemann curvature tensor Rⁱⱼₖₗ.
292        
293        Returns
294        -------
295        dict
296            Nested dict with all non-zero components.
297        
298        Notes
299        -----
300        In 2D, only one independent component exists (up to symmetries).
301        """
302        x, y = self.vars_xy
303        Gamma = self.christoffel
304        
305        R = {}
306        for i in range(2):
307            R[i] = {}
308            for j in range(2):
309                R[i][j] = {}
310                for k in range(2):
311                    R[i][j][k] = {}
312                    for ell in range(2):
313                        expr = diff(Gamma[i][j][ell], [x, y][k])
314                        expr -= diff(Gamma[i][j][k], [x, y][ell])
315                        
316                        for m in range(2):
317                            expr += Gamma[i][m][k] * Gamma[m][j][ell]
318                            expr -= Gamma[i][m][ell] * Gamma[m][j][k]
319                        
320                        R[i][j][k][ell] = simplify(expr)
321        
322        return R

Compute Riemann curvature tensor Rⁱⱼₖₗ.

Returns

dict Nested dict with all non-zero components.

Notes

In 2D, only one independent component exists (up to symmetries).

def ricci_tensor(self):
324    def ricci_tensor(self):
325        """
326        Compute Ricci curvature tensor Rᵢⱼ.
327        
328        Rᵢⱼ = Rᵏᵢₖⱼ (contraction of Riemann tensor)
329        
330        Returns
331        -------
332        sympy Matrix
333            2×2 Ricci tensor.
334        """
335        R_full = self.riemann_tensor()
336        
337        Ric = zeros(2, 2)
338        for i in range(2):
339            for j in range(2):
340                for k in range(2):
341                    Ric[i, j] += R_full[k][i][k][j]
342        
343        return simplify(Ric)

Compute Ricci curvature tensor Rᵢⱼ.

Rᵢⱼ = Rᵏᵢₖⱼ (contraction of Riemann tensor)

Returns

sympy Matrix 2×2 Ricci tensor.

def ricci_scalar(self):
345    def ricci_scalar(self):
346        """
347        Compute scalar curvature R.
348        
349        R = g^ij Rᵢⱼ
350        
351        For 2D surfaces: R = 2K (twice the Gaussian curvature).
352        
353        Returns
354        -------
355        sympy expression
356            Scalar curvature R(x, y).
357        """
358        Ric = self.ricci_tensor()
359        g_inv = self.g_inv_matrix
360        
361        R = 0
362        for i in range(2):
363            for j in range(2):
364                R += g_inv[i, j] * Ric[i, j]
365        
366        return simplify(R)

Compute scalar curvature R.

R = g^ij Rᵢⱼ

For 2D surfaces: R = 2K (twice the Gaussian curvature).

Returns

sympy expression Scalar curvature R(x, y).

def laplace_beltrami_symbol(self):
368    def laplace_beltrami_symbol(self):
369        """
370        Compute symbol of Laplace-Beltrami operator.
371        
372        Principal symbol: g^ij ξᵢ ξⱼ
373        Subprincipal: transport terms from √|g| factor
374        
375        Returns
376        -------
377        dict
378            Symbol components: 'principal', 'subprincipal', 'full'.
379        
380        Examples
381        --------
382        >>> x, y, xi, eta = symbols('x y xi eta', real=True)
383        >>> g = Matrix([[1, 0], [0, 1]])
384        >>> metric = Metric2D(g, (x, y))
385        >>> symbol = metric.laplace_beltrami_symbol()
386        >>> print(symbol['principal'])
387        xi**2 + eta**2
388        """
389        x, y = self.vars_xy
390        xi, eta = symbols('xi eta', real=True)
391        
392        g_inv = self.g_inv_matrix
393        
394        # Principal symbol
395        principal = (g_inv[0,0] * xi**2 + 
396                    2 * g_inv[0,1] * xi * eta +
397                    g_inv[1,1] * eta**2)
398        
399        # Subprincipal (from divergence structure)
400        # ∇·(√g g^ij ∇u) = √g g^ij ∂ᵢ∂ⱼu + (∂ᵢ√g g^ij) ∂ⱼu
401        sqrt_g = self.sqrt_det_g
402        
403        coeff_x = diff(sqrt_g * g_inv[0,0], x) + diff(sqrt_g * g_inv[0,1], y)
404        coeff_y = diff(sqrt_g * g_inv[1,0], x) + diff(sqrt_g * g_inv[1,1], y)
405        
406        subprincipal = simplify((coeff_x * xi + coeff_y * eta) / sqrt_g)
407        
408        return {
409            'principal': simplify(principal),
410            'subprincipal': simplify(subprincipal),
411            'full': simplify(principal + 1j * subprincipal)
412        }

Compute symbol of Laplace-Beltrami operator.

Principal symbol: g^ij ξᵢ ξⱼ Subprincipal: transport terms from √|g| factor

Returns

dict Symbol components: 'principal', 'subprincipal', 'full'.

Examples

>>> x, y, xi, eta = symbols('x y xi eta', real=True)
>>> g = Matrix([[1, 0], [0, 1]])
>>> metric = Metric2D(g, (x, y))
>>> symbol = metric.laplace_beltrami_symbol()
>>> print(symbol['principal'])
xi**2 + eta**2
def riemannian_volume(self, domain, method='numerical'):
414    def riemannian_volume(self, domain, method='numerical'):
415        """
416        Compute Riemannian volume of a domain.
417        
418        Vol(Ω) = ∫∫_Ω √|g| dx dy
419        
420        Parameters
421        ----------
422        domain : tuple
423            For rectangular: ((x_min, x_max), (y_min, y_max)).
424            For custom: callable that defines integration region.
425        method : {'numerical', 'symbolic'}
426            Integration method.
427        
428        Returns
429        -------
430        float or sympy expression
431            Volume of the domain.
432        """
433        x, y = self.vars_xy
434        sqrt_g = self.sqrt_det_g
435        
436        if method == 'symbolic':
437            (x_min, x_max), (y_min, y_max) = domain
438            return integrate(sqrt_g, (x, x_min, x_max), (y, y_min, y_max))
439        
440        elif method == 'numerical':
441            from scipy.integrate import dblquad
442            (x_min, x_max), (y_min, y_max) = domain
443            
444            integrand = lambda y, x: self.sqrt_det_g_func(x, y)
445            result, error = dblquad(integrand, x_min, x_max, y_min, y_max)
446            return result
447        
448        else:
449            raise ValueError("method must be 'symbolic' or 'numerical'")

Compute Riemannian volume of a domain.

Vol(Ω) = ∫∫_Ω √|g| dx dy

Parameters

domain : tuple For rectangular: ((x_min, x_max), (y_min, y_max)). For custom: callable that defines integration region. method : {'numerical', 'symbolic'} Integration method.

Returns

float or sympy expression Volume of the domain.

def geodesic_solver( metric, p0, v0, tspan, method='rk45', n_steps=1000, reparametrize=False):
469def geodesic_solver(metric, p0, v0, tspan, method='rk45', n_steps=1000,
470                   reparametrize=False):
471    """
472    Integrate geodesic equations on 2D manifold.
473    
474    Geodesic equation:
475        ẍⁱ + Γⁱⱼₖ ẋʲ ẋᵏ = 0
476    
477    Parameters
478    ----------
479    metric : Metric2D
480        Riemannian metric.
481    p0 : tuple
482        Initial position (x₀, y₀).
483    v0 : tuple
484        Initial velocity (vₓ₀, vᵧ₀).
485    tspan : tuple
486        Time interval (t_start, t_end).
487    method : str
488        Integration method: 'rk45', 'rk4', 'symplectic', 'verlet'.
489    n_steps : int
490        Number of steps.
491    reparametrize : bool
492        If True, reparametrize by arc length.
493    
494    Returns
495    -------
496    dict
497        Trajectory with 't', 'x', 'y', 'vx', 'vy' arrays.
498    
499    Examples
500    --------
501    >>> x, y = symbols('x y', real=True)
502    >>> g = Matrix([[1, 0], [0, 1]])
503    >>> metric = Metric2D(g, (x, y))
504    >>> traj = geodesic_solver(metric, (0, 0), (1, 1), (0, 10))
505    >>> plt.plot(traj['x'], traj['y'])
506    """
507    from scipy.integrate import solve_ivp
508    
509    Gamma = metric.christoffel_func
510    
511    def geodesic_ode(t, state):
512        x, y, vx, vy = state
513        
514        # Compute accelerations
515        ax = -(Gamma[0][0][0](x, y) * vx**2 +
516               2 * Gamma[0][0][1](x, y) * vx * vy +
517               Gamma[0][1][1](x, y) * vy**2)
518        
519        ay = -(Gamma[1][0][0](x, y) * vx**2 +
520               2 * Gamma[1][0][1](x, y) * vx * vy +
521               Gamma[1][1][1](x, y) * vy**2)
522        
523        return [vx, vy, ax, ay]
524    
525    if method in ['rk45', 'rk4']:
526        sol = solve_ivp(
527            geodesic_ode,
528            tspan,
529            [p0[0], p0[1], v0[0], v0[1]],
530            method='RK45' if method == 'rk45' else 'RK23',
531            t_eval=np.linspace(tspan[0], tspan[1], n_steps)
532        )
533        
534        result = {
535            't': sol.t,
536            'x': sol.y[0],
537            'y': sol.y[1],
538            'vx': sol.y[2],
539            'vy': sol.y[3]
540        }
541    
542    elif method in ['symplectic', 'verlet']:
543        # Use Hamiltonian formulation
544        result = geodesic_hamiltonian_flow(
545            metric, p0, v0, tspan, method='verlet', n_steps=n_steps
546        )
547    
548    else:
549        raise ValueError("Invalid method")
550    
551    # Reparametrize by arc length if requested
552    if reparametrize:
553        # Compute arc length parameter
554        ds = np.sqrt(
555            metric.g_func[(0,0)](result['x'], result['y']) * result['vx']**2 +
556            2 * metric.g_func[(0,1)](result['x'], result['y']) * result['vx'] * result['vy'] +
557            metric.g_func[(1,1)](result['x'], result['y']) * result['vy']**2
558        )
559        
560        # Correction: utiliser l'intégration cumulative trapézoïdale
561        # s commence à 0
562        from scipy.integrate import cumulative_trapezoid
563        s = cumulative_trapezoid(ds, result['t'], initial=0)
564        
565        result['arc_length'] = s
566    
567    return result

Integrate geodesic equations on 2D manifold.

Geodesic equation: ẍⁱ + Γⁱⱼₖ ẋʲ ẋᵏ = 0

Parameters

metric : Metric2D Riemannian metric. p0 : tuple Initial position (x₀, y₀). v0 : tuple Initial velocity (vₓ₀, vᵧ₀). tspan : tuple Time interval (t_start, t_end). method : str Integration method: 'rk45', 'rk4', 'symplectic', 'verlet'. n_steps : int Number of steps. reparametrize : bool If True, reparametrize by arc length.

Returns

dict Trajectory with 't', 'x', 'y', 'vx', 'vy' arrays.

Examples

>>> x, y = symbols('x y', real=True)
>>> g = Matrix([[1, 0], [0, 1]])
>>> metric = Metric2D(g, (x, y))
>>> traj = geodesic_solver(metric, (0, 0), (1, 1), (0, 10))
>>> plt.plot(traj['x'], traj['y'])
def exponential_map(metric, p, v, t=1.0, method='rk45'):
709def exponential_map(metric, p, v, t=1.0, method='rk45'):
710    """
711    Compute exponential map exp_p(tv).
712    
713    The exponential map sends a tangent vector v at point p to the
714    point reached by following the geodesic with initial velocity v
715    for parameter time t.
716    
717    Parameters
718    ----------
719    metric : Metric2D
720        Riemannian metric.
721    p : tuple
722        Base point (x₀, y₀).
723    v : tuple
724Initial tangent vector (vₓ, vᵧ).
725    t : float
726        Parameter value (geodesic "time").
727    method : str
728        Integration method.
729    
730    Returns
731    -------
732    tuple
733        End point (x(t), y(t)).
734    
735    Examples
736    --------
737    >>> x, y = symbols('x y', real=True)
738    >>> g = Matrix([[1, 0], [0, 1]])
739    >>> metric = Metric2D(g, (x, y))
740    >>> q = exponential_map(metric, (0, 0), (1, 1), t=1.0)
741    >>> print(q)  # Should be (1, 1) for flat metric
742    """
743    traj = geodesic_solver(metric, p, v, (0, t), method=method, n_steps=100)
744    return (traj['x'][-1], traj['y'][-1])

Compute exponential map exp_p(tv).

The exponential map sends a tangent vector v at point p to the
point reached by following the geodesic with initial velocity v
for parameter time t.

Parameters
----------
metric : Metric2D
    Riemannian metric.
p : tuple
    Base point (x₀, y₀).
v : tuple

Initial tangent vector (vₓ, vᵧ). t : float Parameter value (geodesic "time"). method : str Integration method.

Returns
-------
tuple
    End point (x(t), y(t)).

Examples
--------
>>> x, y = symbols('x y', real=True)
>>> g = Matrix([[1, 0], [0, 1]])
>>> metric = Metric2D(g, (x, y))
>>> q = exponential_map(metric, (0, 0), (1, 1), t=1.0)
>>> print(q)  # Should be (1, 1) for flat metric
class SymplecticForm1D:
30class SymplecticForm1D:
31    """
32    Symplectic structure on 2D phase space.
33    
34    Represents the symplectic 2-form ω on phase space (x, p).
35    By default, uses the canonical form ω = dx ∧ dp.
36    
37    Parameters
38    ----------
39    omega_expr : sympy Matrix, optional
40        2×2 antisymmetric matrix representing ω.
41        Default is [[0, -1], [1, 0]] (canonical).
42    vars_phase : tuple of sympy symbols
43        Phase space coordinates (x, p).
44    
45    Attributes
46    ----------
47    omega_matrix : sympy Matrix
48        Symplectic form matrix ωᵢⱼ.
49    omega_inv : sympy Matrix
50        Inverse (Poisson tensor) ω^ij.
51    
52    Examples
53    --------
54    >>> x, p = symbols('x p', real=True)
55    >>> omega = SymplecticForm1D(vars_phase=(x, p))
56    >>> print(omega.omega_matrix)
57    Matrix([[0, -1], [1, 0]])
58    """
59    
60    def __init__(self, omega_expr=None, vars_phase=None):
61        if vars_phase is None:
62            x, p = symbols('x p', real=True)
63            self.vars_phase = (x, p)
64        else:
65            self.vars_phase = vars_phase
66        
67        if omega_expr is None:
68            # Canonical symplectic form
69            self.omega_matrix = Matrix([[0, -1], [1, 0]])
70        else:
71            self.omega_matrix = Matrix(omega_expr)
72        
73        # Check antisymmetry
74        if self.omega_matrix != -self.omega_matrix.T:
75            raise ValueError("Symplectic form must be antisymmetric")
76        
77        self.omega_inv = self.omega_matrix.inv()
78    
79    def eval(self, x_val, p_val):
80        """
81        Evaluate symplectic form at a point.
82        
83        Parameters
84        ----------
85        x_val, p_val : float
86            Phase space coordinates.
87        
88        Returns
89        -------
90        ndarray
91            2×2 matrix ωᵢⱼ(x, p).
92        """
93        x, p = self.vars_phase
94        omega_func = lambdify((x, p), self.omega_matrix, 'numpy')
95        return omega_func(x_val, p_val)

Symplectic structure on 2D phase space.

Represents the symplectic 2-form ω on phase space (x, p). By default, uses the canonical form ω = dx ∧ dp.

Parameters

omega_expr : sympy Matrix, optional 2×2 antisymmetric matrix representing ω. Default is [[0, -1], [1, 0]] (canonical). vars_phase : tuple of sympy symbols Phase space coordinates (x, p).

Attributes

omega_matrix : sympy Matrix Symplectic form matrix ωᵢⱼ. omega_inv : sympy Matrix Inverse (Poisson tensor) ω^ij.

Examples

>>> x, p = symbols('x p', real=True)
>>> omega = SymplecticForm1D(vars_phase=(x, p))
>>> print(omega.omega_matrix)
Matrix([[0, -1], [1, 0]])
SymplecticForm1D(omega_expr=None, vars_phase=None)
60    def __init__(self, omega_expr=None, vars_phase=None):
61        if vars_phase is None:
62            x, p = symbols('x p', real=True)
63            self.vars_phase = (x, p)
64        else:
65            self.vars_phase = vars_phase
66        
67        if omega_expr is None:
68            # Canonical symplectic form
69            self.omega_matrix = Matrix([[0, -1], [1, 0]])
70        else:
71            self.omega_matrix = Matrix(omega_expr)
72        
73        # Check antisymmetry
74        if self.omega_matrix != -self.omega_matrix.T:
75            raise ValueError("Symplectic form must be antisymmetric")
76        
77        self.omega_inv = self.omega_matrix.inv()
omega_inv
def eval(self, x_val, p_val):
79    def eval(self, x_val, p_val):
80        """
81        Evaluate symplectic form at a point.
82        
83        Parameters
84        ----------
85        x_val, p_val : float
86            Phase space coordinates.
87        
88        Returns
89        -------
90        ndarray
91            2×2 matrix ωᵢⱼ(x, p).
92        """
93        x, p = self.vars_phase
94        omega_func = lambdify((x, p), self.omega_matrix, 'numpy')
95        return omega_func(x_val, p_val)

Evaluate symplectic form at a point.

Parameters

x_val, p_val : float Phase space coordinates.

Returns

ndarray 2×2 matrix ωᵢⱼ(x, p).

def hamiltonian_flow(H, z0, tspan, integrator='symplectic', n_steps=1000):
 97def hamiltonian_flow(H, z0, tspan, integrator='symplectic', n_steps=1000):
 98    """
 99    Integrate Hamiltonian flow using symplectic integrators.
100    
101    Hamilton's equations:
102        ẋ = ∂H/∂p
103        ṗ = -∂H/∂x
104    
105    Parameters
106    ----------
107    H : sympy expression
108        Hamiltonian function H(x, p).
109    z0 : tuple
110        Initial condition (x₀, p₀).
111    tspan : tuple
112        Time interval (t_start, t_end).
113    integrator : str
114        Integration method: 'symplectic', 'verlet', 'stormer', 'rk45'.
115    n_steps : int
116        Number of time steps.
117    
118    Returns
119    -------
120    dict
121        Trajectory with 't', 'x', 'p', 'energy' arrays.
122    
123    Examples
124    --------
125    >>> # Harmonic oscillator
126    >>> x, p = symbols('x p', real=True)
127    >>> H = (p**2 + x**2) / 2
128    >>> traj = hamiltonian_flow(H, (1, 0), (0, 10*np.pi))
129    >>> plt.plot(traj['x'], traj['p'])
130    
131    Notes
132    -----
133    Symplectic integrators preserve the symplectic structure and
134    exhibit better long-term energy conservation than Runge-Kutta.
135    """
136    from scipy.integrate import solve_ivp
137    
138    x, p = symbols('x p', real=True)
139    
140    # Compute Hamilton's equations
141    dH_dp = diff(H, p)
142    dH_dx = diff(H, x)
143    
144    # Lambdify
145    f_x = lambdify((x, p), dH_dp, 'numpy')
146    f_p = lambdify((x, p), -dH_dx, 'numpy')
147    H_func = lambdify((x, p), H, 'numpy')
148    
149    if integrator == 'rk45':
150        def ode_system(t, z):
151            x_val, p_val = z
152            return [f_x(x_val, p_val), f_p(x_val, p_val)]
153        
154        sol = solve_ivp(
155            ode_system,
156            tspan,
157            z0,
158            method='RK45',
159            t_eval=np.linspace(tspan[0], tspan[1], n_steps)
160        )
161        
162        return {
163            't': sol.t,
164            'x': sol.y[0],
165            'p': sol.y[1],
166            'energy': H_func(sol.y[0], sol.y[1])
167        }
168    
169    elif integrator in ['symplectic', 'verlet', 'stormer']:
170        dt = (tspan[1] - tspan[0]) / n_steps
171        t_vals = np.linspace(tspan[0], tspan[1], n_steps)
172        x_vals = np.zeros(n_steps)
173        p_vals = np.zeros(n_steps)
174        
175        x_vals[0], p_vals[0] = z0
176        
177        # Prepare second derivatives for Verlet
178        if integrator in ['verlet', 'stormer']:
179            d2H_dp2 = lambdify((x, p), diff(H, p, 2), 'numpy')
180            d2H_dxdp = lambdify((x, p), diff(diff(H, x), p), 'numpy')
181            d2H_dx2 = lambdify((x, p), diff(H, x, 2), 'numpy')
182        
183        for i in range(n_steps - 1):
184            x_curr = x_vals[i]
185            p_curr = p_vals[i]
186            
187            if integrator == 'symplectic':
188                # Symplectic Euler
189                p_new = p_curr + dt * f_p(x_curr, p_curr)
190                x_new = x_curr + dt * f_x(x_curr, p_new)
191            
192            elif integrator in ['verlet', 'stormer']:
193                # Velocity Verlet / Störmer-Verlet
194                # Half-step momentum
195                p_half = p_curr + 0.5 * dt * f_p(x_curr, p_curr)
196                
197                # Full-step position
198                x_new = x_curr + dt * f_x(x_curr, p_half)
199                
200                # Half-step momentum (complete)
201                p_new = p_half + 0.5 * dt * f_p(x_new, p_half)
202            
203            x_vals[i+1] = x_new
204            p_vals[i+1] = p_new
205        
206        energy = H_func(x_vals, p_vals)
207        
208        return {
209            't': t_vals,
210            'x': x_vals,
211            'p': p_vals,
212            'energy': energy
213        }
214    
215    else:
216        raise ValueError("Invalid integrator")

Integrate Hamiltonian flow using symplectic integrators.

Hamilton's equations: ẋ = ∂H/∂p ṗ = -∂H/∂x

Parameters

H : sympy expression Hamiltonian function H(x, p). z0 : tuple Initial condition (x₀, p₀). tspan : tuple Time interval (t_start, t_end). integrator : str Integration method: 'symplectic', 'verlet', 'stormer', 'rk45'. n_steps : int Number of time steps.

Returns

dict Trajectory with 't', 'x', 'p', 'energy' arrays.

Examples

>>> # Harmonic oscillator
>>> x, p = symbols('x p', real=True)
>>> H = (p**2 + x**2) / 2
>>> traj = hamiltonian_flow(H, (1, 0), (0, 10*np.pi))
>>> plt.plot(traj['x'], traj['p'])

Notes

Symplectic integrators preserve the symplectic structure and exhibit better long-term energy conservation than Runge-Kutta.

def poisson_bracket(f, g, vars_phase=None):
219def poisson_bracket(f, g, vars_phase=None):
220    """
221    Compute Poisson bracket {f, g}.
222    
223    {f, g} = ∂f/∂x ∂g/∂p - ∂f/∂p ∂g/∂x
224    
225    Parameters
226    ----------
227    f, g : sympy expressions
228        Functions on phase space.
229    vars_phase : tuple, optional
230        Phase space variables (x, p). If None, inferred from f and g.
231    
232    Returns
233    -------
234    sympy expression
235        Poisson bracket {f, g}.
236    
237    Examples
238    --------
239    >>> x, p = symbols('x p', real=True)
240    >>> f = x**2
241    >>> g = p**2
242    >>> pb = poisson_bracket(f, g)
243    >>> print(pb)
244    4*x*p
245    
246    >>> # Fundamental brackets
247    >>> print(poisson_bracket(x, p))  # Should be 1
248    1
249    >>> print(poisson_bracket(p, x))  # Should be -1
250    -1
251    """
252    if vars_phase is None:
253        # Infer from expressions
254        free_syms = f.free_symbols.union(g.free_symbols)
255        
256        # Try to identify x and p
257        # Convention: look for variables named 'x' and 'p'
258        x_candidates = [s for s in free_syms if 'x' in str(s).lower()]
259        p_candidates = [s for s in free_syms if 'p' in str(s).lower()]
260        
261        if len(x_candidates) == 1 and len(p_candidates) == 1:
262            x = x_candidates[0]
263            p = p_candidates[0]
264            vars_phase = (x, p)
265        else:
266            # Fall back to sorted order (alphabetically)
267            vars_list = sorted(free_syms, key=str)
268            if len(vars_list) == 2:
269                vars_phase = tuple(vars_list)
270            else:
271                raise ValueError(
272                    f"Cannot infer phase space variables from {free_syms}. "
273                    "Please provide vars_phase explicitly."
274                )
275    
276    x, p = vars_phase
277    
278    # Compute Poisson bracket: {f, g} = ∂f/∂x ∂g/∂p - ∂f/∂p ∂g/∂x
279    df_dx = diff(f, x)
280    df_dp = diff(f, p)
281    dg_dx = diff(g, x)
282    dg_dp = diff(g, p)
283    
284    bracket = df_dx * dg_dp - df_dp * dg_dx
285    
286    return simplify(bracket)

Compute Poisson bracket {f, g}.

{f, g} = ∂f/∂x ∂g/∂p - ∂f/∂p ∂g/∂x

Parameters

f, g : sympy expressions Functions on phase space. vars_phase : tuple, optional Phase space variables (x, p). If None, inferred from f and g.

Returns

sympy expression Poisson bracket {f, g}.

Examples

>>> x, p = symbols('x p', real=True)
>>> f = x**2
>>> g = p**2
>>> pb = poisson_bracket(f, g)
>>> print(pb)
4*x*p
>>> # Fundamental brackets
>>> print(poisson_bracket(x, p))  # Should be 1
1
>>> print(poisson_bracket(p, x))  # Should be -1
-1
class SymplecticForm2D:
30class SymplecticForm2D:
31    """
32    Symplectic structure on 4D phase space.
33    
34    Represents the symplectic 2-form ω on phase space (x₁, p₁, x₂, p₂).
35    By default, uses canonical form ω = dx₁∧dp₁ + dx₂∧dp₂.
36    
37    Parameters
38    ----------
39    omega_matrix : 4×4 sympy Matrix, optional
40        Antisymmetric matrix representing ω.
41    vars_phase : tuple of sympy symbols
42        Phase space coordinates (x₁, p₁, x₂, p₂).
43    
44    Examples
45    --------
46    >>> x1, p1, x2, p2 = symbols('x1 p1 x2 p2', real=True)
47    >>> omega = SymplecticForm2D(vars_phase=(x1, p1, x2, p2))
48    """
49    
50    def __init__(self, omega_matrix=None, vars_phase=None):
51        if vars_phase is None:
52            x1, p1, x2, p2 = symbols('x1 p1 x2 p2', real=True)
53            self.vars_phase = (x1, p1, x2, p2)
54        else:
55            self.vars_phase = vars_phase
56        
57        if omega_matrix is None:
58            # Canonical symplectic form
59            self.omega_matrix = Matrix([
60                [0, -1,  0,  0],
61                [1,  0,  0,  0],
62                [0,  0,  0, -1],
63                [0,  0,  1,  0]
64            ])
65        else:
66            self.omega_matrix = Matrix(omega_matrix)
67        
68        # Check antisymmetry
69        if self.omega_matrix != -self.omega_matrix.T:
70            raise ValueError("Symplectic form must be antisymmetric")
71        
72        self.omega_inv = self.omega_matrix.inv()

Symplectic structure on 4D phase space.

Represents the symplectic 2-form ω on phase space (x₁, p₁, x₂, p₂). By default, uses canonical form ω = dx₁∧dp₁ + dx₂∧dp₂.

Parameters

omega_matrix : 4×4 sympy Matrix, optional Antisymmetric matrix representing ω. vars_phase : tuple of sympy symbols Phase space coordinates (x₁, p₁, x₂, p₂).

Examples

>>> x1, p1, x2, p2 = symbols('x1 p1 x2 p2', real=True)
>>> omega = SymplecticForm2D(vars_phase=(x1, p1, x2, p2))
SymplecticForm2D(omega_matrix=None, vars_phase=None)
50    def __init__(self, omega_matrix=None, vars_phase=None):
51        if vars_phase is None:
52            x1, p1, x2, p2 = symbols('x1 p1 x2 p2', real=True)
53            self.vars_phase = (x1, p1, x2, p2)
54        else:
55            self.vars_phase = vars_phase
56        
57        if omega_matrix is None:
58            # Canonical symplectic form
59            self.omega_matrix = Matrix([
60                [0, -1,  0,  0],
61                [1,  0,  0,  0],
62                [0,  0,  0, -1],
63                [0,  0,  1,  0]
64            ])
65        else:
66            self.omega_matrix = Matrix(omega_matrix)
67        
68        # Check antisymmetry
69        if self.omega_matrix != -self.omega_matrix.T:
70            raise ValueError("Symplectic form must be antisymmetric")
71        
72        self.omega_inv = self.omega_matrix.inv()
omega_inv
def hamiltonian_flow_4d(H, z0, tspan, integrator='symplectic', n_steps=1000):
 75def hamiltonian_flow_4d(H, z0, tspan, integrator='symplectic', n_steps=1000):
 76    """
 77    Integrate Hamiltonian flow in 4D phase space.
 78    
 79    Hamilton's equations:
 80        ẋᵢ = ∂H/∂pᵢ
 81        ṗᵢ = -∂H/∂xᵢ
 82    
 83    Parameters
 84    ----------
 85    H : sympy expression
 86        Hamiltonian H(x₁, p₁, x₂, p₂).
 87    z0 : tuple or array
 88        Initial condition (x₁, p₁, x₂, p₂).
 89    tspan : tuple
 90        Time interval (t_start, t_end).
 91    integrator : str
 92        Integration method: 'symplectic', 'verlet', 'rk45'.
 93    n_steps : int
 94        Number of time steps.
 95    
 96    Returns
 97    -------
 98    dict
 99        Trajectory with 't', 'x1', 'p1', 'x2', 'p2', 'energy' arrays.
100    
101    Examples
102    --------
103    >>> # Coupled oscillators
104    >>> x1, p1, x2, p2 = symbols('x1 p1 x2 p2', real=True)
105    >>> H = (p1**2 + p2**2)/2 + (x1**2 + x2**2)/2 + 0.1*x1*x2
106    >>> traj = hamiltonian_flow_4d(H, (1, 0, 0.5, 0), (0, 50))
107    """
108    from scipy.integrate import solve_ivp
109    
110    x1, p1, x2, p2 = symbols('x1 p1 x2 p2', real=True)
111    
112    # Hamilton's equations
113    dH_dp1 = diff(H, p1)
114    dH_dp2 = diff(H, p2)
115    dH_dx1 = diff(H, x1)
116    dH_dx2 = diff(H, x2)
117    
118    # Lambdify
119    f_x1 = lambdify((x1, p1, x2, p2), dH_dp1, 'numpy')
120    f_x2 = lambdify((x1, p1, x2, p2), dH_dp2, 'numpy')
121    f_p1 = lambdify((x1, p1, x2, p2), -dH_dx1, 'numpy')
122    f_p2 = lambdify((x1, p1, x2, p2), -dH_dx2, 'numpy')
123    H_func = lambdify((x1, p1, x2, p2), H, 'numpy')
124    
125    if integrator == 'rk45':
126        def ode_system(t, z):
127            x1_val, p1_val, x2_val, p2_val = z
128            return [
129                f_x1(x1_val, p1_val, x2_val, p2_val),
130                f_p1(x1_val, p1_val, x2_val, p2_val),
131                f_x2(x1_val, p1_val, x2_val, p2_val),
132                f_p2(x1_val, p1_val, x2_val, p2_val)
133            ]
134        
135        sol = solve_ivp(
136            ode_system,
137            tspan,
138            z0,
139            method='RK45',
140            t_eval=np.linspace(tspan[0], tspan[1], n_steps),
141            rtol=1e-9,
142            atol=1e-12
143        )
144        
145        return {
146            't': sol.t,
147            'x1': sol.y[0],
148            'p1': sol.y[1],
149            'x2': sol.y[2],
150            'p2': sol.y[3],
151            'energy': H_func(sol.y[0], sol.y[1], sol.y[2], sol.y[3])
152        }
153    
154    elif integrator in ['symplectic', 'verlet']:
155        dt = (tspan[1] - tspan[0]) / n_steps
156        t_vals = np.linspace(tspan[0], tspan[1], n_steps)
157        
158        x1_vals = np.zeros(n_steps)
159        p1_vals = np.zeros(n_steps)
160        x2_vals = np.zeros(n_steps)
161        p2_vals = np.zeros(n_steps)
162        
163        x1_vals[0], p1_vals[0], x2_vals[0], p2_vals[0] = z0
164        
165        for i in range(n_steps - 1):
166            x1_curr = x1_vals[i]
167            p1_curr = p1_vals[i]
168            x2_curr = x2_vals[i]
169            p2_curr = p2_vals[i]
170            
171            if integrator == 'symplectic':
172                # Symplectic Euler
173                p1_new = p1_curr + dt * f_p1(x1_curr, p1_curr, x2_curr, p2_curr)
174                p2_new = p2_curr + dt * f_p2(x1_curr, p1_curr, x2_curr, p2_curr)
175                
176                x1_new = x1_curr + dt * f_x1(x1_curr, p1_new, x2_curr, p2_new)
177                x2_new = x2_curr + dt * f_x2(x1_curr, p1_new, x2_curr, p2_new)
178            
179            elif integrator == 'verlet':
180                # Velocity Verlet
181                p1_half = p1_curr + 0.5 * dt * f_p1(x1_curr, p1_curr, x2_curr, p2_curr)
182                p2_half = p2_curr + 0.5 * dt * f_p2(x1_curr, p1_curr, x2_curr, p2_curr)
183                
184                x1_new = x1_curr + dt * f_x1(x1_curr, p1_half, x2_curr, p2_half)
185                x2_new = x2_curr + dt * f_x2(x1_curr, p1_half, x2_curr, p2_half)
186                
187                p1_new = p1_half + 0.5 * dt * f_p1(x1_new, p1_half, x2_new, p2_half)
188                p2_new = p2_half + 0.5 * dt * f_p2(x1_new, p1_half, x2_new, p2_half)
189            
190            x1_vals[i+1] = x1_new
191            p1_vals[i+1] = p1_new
192            x2_vals[i+1] = x2_new
193            p2_vals[i+1] = p2_new
194        
195        energy = H_func(x1_vals, p1_vals, x2_vals, p2_vals)
196        
197        return {
198            't': t_vals,
199            'x1': x1_vals,
200            'p1': p1_vals,
201            'x2': x2_vals,
202            'p2': p2_vals,
203            'energy': energy
204        }
205    
206    else:
207        raise ValueError("Invalid integrator")

Integrate Hamiltonian flow in 4D phase space.

Hamilton's equations: ẋᵢ = ∂H/∂pᵢ ṗᵢ = -∂H/∂xᵢ

Parameters

H : sympy expression Hamiltonian H(x₁, p₁, x₂, p₂). z0 : tuple or array Initial condition (x₁, p₁, x₂, p₂). tspan : tuple Time interval (t_start, t_end). integrator : str Integration method: 'symplectic', 'verlet', 'rk45'. n_steps : int Number of time steps.

Returns

dict Trajectory with 't', 'x1', 'p1', 'x2', 'p2', 'energy' arrays.

Examples

>>> # Coupled oscillators
>>> x1, p1, x2, p2 = symbols('x1 p1 x2 p2', real=True)
>>> H = (p1**2 + p2**2)/2 + (x1**2 + x2**2)/2 + 0.1*x1*x2
>>> traj = hamiltonian_flow_4d(H, (1, 0, 0.5, 0), (0, 50))
def poincare_section(H, Sigma_def, z0, tmax, n_returns=1000, integrator='symplectic'):
210def poincare_section(H, Sigma_def, z0, tmax, n_returns=1000, 
211                     integrator='symplectic'):
212    """
213    Compute Poincaré section (surface of section).
214    
215    A Poincaré section Σ is a codimension-1 surface in phase space.
216    Records points where trajectory intersects Σ.
217    
218    Parameters
219    ----------
220    H : sympy expression
221        Hamiltonian H(x₁, p₁, x₂, p₂).
222    Sigma_def : dict
223        Section definition with 'variable', 'value', 'direction'.
224        Example: {'variable': 'x2', 'value': 0, 'direction': 'positive'}
225    z0 : tuple
226        Initial condition.
227    tmax : float
228        Maximum integration time.
229    n_returns : int
230        Maximum number of returns to section.
231    integrator : str
232        Integration method.
233    
234    Returns
235    -------
236    dict
237        Section points: 't_crossings', 'section_points'.
238    
239    Examples
240    --------
241    >>> x1, p1, x2, p2 = symbols('x1 p1 x2 p2', real=True)
242    >>> H = (p1**2 + p2**2 + x1**2 + x2**2) / 2
243    >>> section = {'variable': 'x2', 'value': 0, 'direction': 'positive'}
244    >>> ps = poincare_section(H, section, (1, 0, 0, 0.5), tmax=100)
245    """
246    # Integrate trajectory
247    n_steps = 10000
248    traj = hamiltonian_flow_4d(H, z0, (0, tmax), integrator=integrator, 
249                               n_steps=n_steps)
250    
251    # Extract section variable
252    var_name = Sigma_def['variable']
253    var_values = traj[var_name]
254    var_threshold = Sigma_def['value']
255    direction = Sigma_def.get('direction', 'positive')
256    
257    # Find crossings
258    crossings = []
259    section_points = []
260    
261    for i in range(len(var_values) - 1):
262        v_curr = var_values[i]
263        v_next = var_values[i+1]
264        
265        # Check crossing
266        if direction == 'positive':
267            crosses = (v_curr < var_threshold) and (v_next >= var_threshold)
268        elif direction == 'negative':
269            crosses = (v_curr > var_threshold) and (v_next <= var_threshold)
270        else:  # 'both'
271            crosses = (v_curr - var_threshold) * (v_next - var_threshold) < 0
272        
273        if crosses:
274            # Linear interpolation for crossing time
275            alpha = (var_threshold - v_curr) / (v_next - v_curr)
276            t_cross = traj['t'][i] + alpha * (traj['t'][i+1] - traj['t'][i])
277            
278            # Interpolate all variables
279            point = {}
280            for key in ['x1', 'p1', 'x2', 'p2']:
281                point[key] = traj[key][i] + alpha * (traj[key][i+1] - traj[key][i])
282            
283            crossings.append(t_cross)
284            section_points.append(point)
285            
286            if len(crossings) >= n_returns:
287                break
288    
289    return {
290        't_crossings': np.array(crossings),
291        'section_points': section_points
292    }

Compute Poincaré section (surface of section).

A Poincaré section Σ is a codimension-1 surface in phase space. Records points where trajectory intersects Σ.

Parameters

H : sympy expression Hamiltonian H(x₁, p₁, x₂, p₂). Sigma_def : dict Section definition with 'variable', 'value', 'direction'. Example: {'variable': 'x2', 'value': 0, 'direction': 'positive'} z0 : tuple Initial condition. tmax : float Maximum integration time. n_returns : int Maximum number of returns to section. integrator : str Integration method.

Returns

dict Section points: 't_crossings', 'section_points'.

Examples

>>> x1, p1, x2, p2 = symbols('x1 p1 x2 p2', real=True)
>>> H = (p1**2 + p2**2 + x1**2 + x2**2) / 2
>>> section = {'variable': 'x2', 'value': 0, 'direction': 'positive'}
>>> ps = poincare_section(H, section, (1, 0, 0, 0.5), tmax=100)
def characteristic_variety(symbol, tol=1e-08):
29def characteristic_variety(symbol, tol=1e-8):
30    """
31    Compute characteristic variety of a pseudo-differential operator.
32    
33    Char(P) = {(x, ξ) ∈ T*ℝ : p(x, ξ) = 0}
34    
35    where p(x, ξ) is the principal symbol.
36    
37    Parameters
38    ----------
39    symbol : sympy expression
40        Principal symbol p(x, ξ).
41    tol : float
42        Tolerance for zero detection.
43    
44    Returns
45    -------
46    dict
47        Contains symbolic and numerical representations.
48    
49    Examples
50    --------
51    >>> x, xi = symbols('x xi', real=True)
52    >>> p = xi**2 - x**2  # Wave operator
53    >>> char = characteristic_variety(p)
54    >>> print(char['implicit'])
55    xi**2 - x**2
56    
57    Notes
58    -----
59    The characteristic variety determines where the operator
60    fails to be elliptic and where singularities propagate.
61    """
62    x, xi = symbols('x xi', real=True)
63    
64    # Symbolic characteristic set
65    char_eq = Eq(symbol, 0)
66    
67    # Try to solve for ξ(x)
68    try:
69        xi_solutions = solve(symbol, xi)
70        explicit_curves = [simplify(sol) for sol in xi_solutions]
71    except:
72        explicit_curves = None
73    
74    # Lambdify for numerical evaluation
75    char_func = lambdify((x, xi), symbol, 'numpy')
76    
77    return {
78        'implicit': symbol,
79        'equation': char_eq,
80        'explicit': explicit_curves,
81        'function': char_func
82    }

Compute characteristic variety of a pseudo-differential operator.

Char(P) = {(x, ξ) ∈ T*ℝ : p(x, ξ) = 0}

where p(x, ξ) is the principal symbol.

Parameters

symbol : sympy expression Principal symbol p(x, ξ). tol : float Tolerance for zero detection.

Returns

dict Contains symbolic and numerical representations.

Examples

>>> x, xi = symbols('x xi', real=True)
>>> p = xi**2 - x**2  # Wave operator
>>> char = characteristic_variety(p)
>>> print(char['implicit'])
xi**2 - x**2

Notes

The characteristic variety determines where the operator fails to be elliptic and where singularities propagate.

def bicharacteristic_flow(symbol, z0, tspan, method='hamiltonian', n_steps=1000):
 85def bicharacteristic_flow(symbol, z0, tspan, method='hamiltonian', n_steps=1000):
 86    """
 87    Integrate bicharacteristic flow on cotangent bundle T*ℝ.
 88    
 89    The bicharacteristic equations are Hamilton's equations with
 90    Hamiltonian H = p(x, ξ):
 91        ẋ = ∂p/∂ξ
 92        ξ̇ = -∂p/∂x
 93    
 94    Parameters
 95    ----------
 96    symbol : sympy expression
 97        Principal symbol p(x, ξ).
 98    z0 : tuple
 99        Initial condition (x₀, ξ₀) on T*ℝ.
100    tspan : tuple
101        Time interval (t_start, t_end).
102    method : str
103        Integration method: 'hamiltonian', 'symplectic', 'rk45'.
104    n_steps : int
105        Number of time steps.
106    
107    Returns
108    -------
109    dict
110        Bicharacteristic curve: 't', 'x', 'xi', 'symbol_value'.
111    
112    Examples
113    --------
114    >>> x, xi = symbols('x xi', real=True)
115    >>> p = xi**2 + x**2  # Elliptic
116    >>> traj = bicharacteristic_flow(p, (1, 1), (0, 10))
117    >>> plt.plot(traj['x'], traj['xi'])
118    
119    Notes
120    -----
121    Bicharacteristics are the rays along which singularities propagate.
122    They are null geodesics with respect to the symbol's metric.
123    """
124    from scipy.integrate import solve_ivp
125    
126    x, xi = symbols('x xi', real=True)
127    
128    # Compute Hamiltonian vector field
129    dp_dxi = diff(symbol, xi)
130    dp_dx = diff(symbol, x)
131    
132    # Lambdify
133    f_x = lambdify((x, xi), dp_dxi, 'numpy')
134    f_xi = lambdify((x, xi), -dp_dx, 'numpy')
135    p_func = lambdify((x, xi), symbol, 'numpy')
136    
137    if method == 'rk45':
138        def ode_system(t, z):
139            x_val, xi_val = z
140            return [f_x(x_val, xi_val), f_xi(x_val, xi_val)]
141        
142        sol = solve_ivp(
143            ode_system,
144            tspan,
145            z0,
146            method='RK45',
147            t_eval=np.linspace(tspan[0], tspan[1], n_steps),
148            rtol=1e-9,
149            atol=1e-12
150        )
151        
152        return {
153            't': sol.t,
154            'x': sol.y[0],
155            'xi': sol.y[1],
156            'symbol_value': p_func(sol.y[0], sol.y[1])
157        }
158    
159    elif method in ['hamiltonian', 'symplectic']:
160        dt = (tspan[1] - tspan[0]) / n_steps
161        t_vals = np.linspace(tspan[0], tspan[1], n_steps)
162        x_vals = np.zeros(n_steps)
163        xi_vals = np.zeros(n_steps)
164        
165        x_vals[0], xi_vals[0] = z0
166        
167        for i in range(n_steps - 1):
168            x_curr = x_vals[i]
169            xi_curr = xi_vals[i]
170            
171            # Symplectic Euler
172            xi_new = xi_curr + dt * f_xi(x_curr, xi_curr)
173            x_new = x_curr + dt * f_x(x_curr, xi_new)
174            
175            x_vals[i+1] = x_new
176            xi_vals[i+1] = xi_new
177        
178        return {
179            't': t_vals,
180            'x': x_vals,
181            'xi': xi_vals,
182            'symbol_value': p_func(x_vals, xi_vals)
183        }
184    
185    else:
186        raise ValueError("Invalid method")

Integrate bicharacteristic flow on cotangent bundle T*ℝ.

The bicharacteristic equations are Hamilton's equations with Hamiltonian H = p(x, ξ): ẋ = ∂p/∂ξ ξ̇ = -∂p/∂x

Parameters

symbol : sympy expression Principal symbol p(x, ξ). z0 : tuple Initial condition (x₀, ξ₀) on T*ℝ. tspan : tuple Time interval (t_start, t_end). method : str Integration method: 'hamiltonian', 'symplectic', 'rk45'. n_steps : int Number of time steps.

Returns

dict Bicharacteristic curve: 't', 'x', 'xi', 'symbol_value'.

Examples

>>> x, xi = symbols('x xi', real=True)
>>> p = xi**2 + x**2  # Elliptic
>>> traj = bicharacteristic_flow(p, (1, 1), (0, 10))
>>> plt.plot(traj['x'], traj['xi'])

Notes

Bicharacteristics are the rays along which singularities propagate. They are null geodesics with respect to the symbol's metric.

def wkb_ansatz(symbol, initial_phase, order=1, x_domain=(-5, 5), n_points=200):
189def wkb_ansatz(symbol, initial_phase, order=1, x_domain=(-5, 5), n_points=200):
190    """
191    Compute WKB approximation u(x) ≈ a(x) e^(iS(x)/ε).
192    
193    Solves eikonal and transport equations:
194        Eikonal: p(x, S'(x)) = 0
195        Transport: ∂_ξp · a' + ½(∂²_ξξp) S'' a = 0
196    
197    Parameters
198    ----------
199    symbol : sympy expression
200        Principal symbol p(x, ξ).
201    initial_phase : dict
202        Initial data: {'x0': x₀, 'S0': S₀, 'Sp0': S'₀}.
203    order : int
204        Order of WKB expansion (0 or 1).
205    x_domain : tuple
206        Spatial domain for solution.
207    n_points : int
208        Number of grid points.
209    
210    Returns
211    -------
212    dict
213        WKB solution: 'x', 'S' (phase), 'a' (amplitude), 'u' (full solution).
214    
215    Examples
216    --------
217    >>> x, xi = symbols('x xi', real=True)
218    >>> p = xi**2 - x  # Airy equation
219    >>> ic = {'x0': 0, 'S0': 0, 'Sp0': 1}
220    >>> wkb = wkb_ansatz(p, ic)
221    >>> plt.plot(wkb['x'], np.real(wkb['u']))
222    
223    Notes
224    -----
225    WKB breaks down at caustics where S'(x) becomes multivalued.
226    """
227    from scipy.integrate import odeint
228    
229    x, xi = symbols('x xi', real=True)
230    
231    # Eikonal equation: p(x, S'(x)) = 0
232    # Solve for S'(x) implicitly
233    
234    x0 = initial_phase['x0']
235    S0 = initial_phase['S0']
236    Sp0 = initial_phase['Sp0']  # S'(x₀) = ξ₀
237    
238    # Compute derivatives of p
239    dp_dxi = diff(symbol, xi)
240    dp_dx = diff(symbol, x)
241    d2p_dxi2 = diff(symbol, xi, 2)
242    
243    # Lambdify
244    dp_dxi_func = lambdify((x, xi), dp_dxi, 'numpy')
245    dp_dx_func = lambdify((x, xi), dp_dx, 'numpy')
246    d2p_dxi2_func = lambdify((x, xi), d2p_dxi2, 'numpy')
247    p_func = lambdify((x, xi), symbol, 'numpy')
248    
249    # Setup ODEs for phase and amplitude
250    def ode_system(y, x_val):
251        """
252        y = [S, S', a, a']
253        
254        S'' = -∂_x p / ∂_ξ p  (from eikonal)
255        a' = given by transport equation
256        """
257        S_val, Sp_val, a_val, ap_val = y
258        
259        # Eikonal: dS'/dx
260        denom = dp_dxi_func(x_val, Sp_val)
261        if abs(denom) < 1e-10:
262            # Caustic point
263            Spp = 0
264        else:
265            Spp = -dp_dx_func(x_val, Sp_val) / denom
266        
267        # Transport equation (simplified)
268        # ∂_ξp · a' + ½(∂²_ξξp) S'' a = 0
269        if order >= 1 and abs(denom) > 1e-10:
270            transport_coeff = 0.5 * d2p_dxi2_func(x_val, Sp_val) * Spp / denom
271            app = -transport_coeff * a_val
272        else:
273            app = 0
274        
275        return [Sp_val, Spp, ap_val, app]
276    
277    # Initial conditions
278    a0 = 1.0  # Initial amplitude
279    ap0 = 0.0
280    y0 = [S0, Sp0, a0, ap0]
281    
282    # Integrate
283    x_vals = np.linspace(x_domain[0], x_domain[1], n_points)
284    
285    # Split integration if x0 not at boundary
286    if abs(x_vals[0] - x0) > 1e-6:
287        # Forward integration
288        x_forward = x_vals[x_vals >= x0]
289        sol_forward = odeint(ode_system, y0, x_forward)
290        
291        # Backward integration
292        x_backward = x_vals[x_vals < x0][::-1]
293        sol_backward = odeint(ode_system, y0, x_backward)
294        sol_backward = sol_backward[::-1]
295        
296        # Combine
297        x_vals = np.concatenate([x_backward, x_forward])
298        sol = np.vstack([sol_backward, sol_forward])
299    else:
300        sol = odeint(ode_system, y0, x_vals)
301    
302    S_vals = sol[:, 0]
303    a_vals = sol[:, 2]
304    
305    # Construct WKB solution (with ε = 1 for visualization)
306    u_vals = a_vals * np.exp(1j * S_vals)
307    
308    return {
309        'x': x_vals,
310        'S': S_vals,
311        'Sp': sol[:, 1],
312        'a': a_vals,
313        'u': u_vals
314    }

Compute WKB approximation u(x) ≈ a(x) e^(iS(x)/ε).

Solves eikonal and transport equations: Eikonal: p(x, S'(x)) = 0 Transport: ∂_ξp · a' + ½(∂²_ξξp) S'' a = 0

Parameters

symbol : sympy expression Principal symbol p(x, ξ). initial_phase : dict Initial data: {'x0': x₀, 'S0': S₀, 'Sp0': S'₀}. order : int Order of WKB expansion (0 or 1). x_domain : tuple Spatial domain for solution. n_points : int Number of grid points.

Returns

dict WKB solution: 'x', 'S' (phase), 'a' (amplitude), 'u' (full solution).

Examples

>>> x, xi = symbols('x xi', real=True)
>>> p = xi**2 - x  # Airy equation
>>> ic = {'x0': 0, 'S0': 0, 'Sp0': 1}
>>> wkb = wkb_ansatz(p, ic)
>>> plt.plot(wkb['x'], np.real(wkb['u']))

Notes

WKB breaks down at caustics where S'(x) becomes multivalued.

def bohr_sommerfeld_quantization(H, n_max=10, x_range=(-10, 10), hbar=1.0, method='fast'):
316def bohr_sommerfeld_quantization(H, n_max=10, x_range=(-10, 10),
317                                 hbar=1.0, method='fast'):
318    """
319    Compute Bohr-Sommerfeld quantization condition.
320    
321    For bound states in 1D:
322        (1/(2π)) ∮ p dx = ℏ(n + α)
323    
324    where α is the Maslov index correction (typically 1/2 or 1/4).
325    
326    Parameters
327    ----------
328    H : sympy expression
329        Hamiltonian H(x, p).
330    n_max : int
331        Maximum quantum number to compute.
332    x_range : tuple
333        Spatial range for classical turning points.
334    hbar : float
335        Planck's constant (set to 1 in natural units).
336    method : str
337        Computation method: 'contour', 'approximate'.
338    
339    Returns
340    -------
341    dict
342        Quantized energies: 'n', 'E_n', 'actions'.
343    
344    Examples
345    --------
346    >>> x, p = symbols('x p', real=True)
347    >>> H = p**2/2 + x**2/2  # Harmonic oscillator
348    >>> quant = bohr_sommerfeld_quantization(H, n_max=5)
349    >>> print(quant['E_n'])  # Should be E_n = (n + 1/2)ℏω
350    
351    Notes
352    -----
353    This is the semiclassical quantization condition, exact for
354    harmonic oscillator, accurate for slowly varying potentials.
355    """
356    import numpy as np
357    from scipy.integrate import quad
358    from scipy.optimize import bisect
359    from sympy import symbols, solve, lambdify
360
361    x, p = symbols('x p', real=True)
362    E_sym = symbols('E', real=True, positive=True)
363
364    # Solve H(x,p)=E → p(x,E)
365    p_solutions = solve(H - E_sym, p)
366    if not p_solutions:
367        raise ValueError("Unable to solve H=E for p(x,E).")
368
369    # Keep the branch with positive momentum
370    p_expr = p_solutions[-1]
371    p_func = lambdify((x, E_sym), p_expr, 'numpy')
372
373    alpha = 0.5  # Maslov index
374
375    energies = []
376    actions = []
377    quantum_numbers = []
378
379    # Find turning points from the sign of p^2
380    X = np.linspace(x_range[0], x_range[1], 2000)
381
382    def action(E):
383        """Compute classical action I(E) = (1/pi) ∫ p dx."""
384        p_vals = p_func(X, E)
385        p_vals = np.real_if_close(p_vals)
386        p_vals = np.real(p_vals)
387        
388        # Handle case where p_vals is a scalar (independent of x)
389        if np.ndim(p_vals) == 0:
390            # For free particle or x-independent momentum, no bound states
391            return 0.0
392        
393        mask = p_vals >= 0
394        if not np.any(mask):
395            return 0.0
396    
397        # locate turning region
398        idx = np.where(mask)[0]
399        if len(idx) == 0:
400            return 0.0
401        
402        a = X[idx[0]]
403        b = X[idx[-1]]
404    
405        def integrand(xv):
406            pv = p_func(xv, E)
407            return np.sqrt(max(pv, 0))
408    
409        I, _ = quad(integrand, a, b, epsabs=1e-10, epsrel=1e-10)
410        return I / np.pi
411
412    # target quantized actions
413    targets = [hbar*(n + alpha) for n in range(n_max)]
414
415    # Energy brackets to scan
416    E_scan = np.linspace(1e-6, 50, 200)
417
418    I_scan = [action(E) for E in E_scan]
419
420    for n, Itarget in zip(range(n_max), targets):
421
422        # Need an interval where action crosses target
423        found = False
424        for k in range(len(E_scan)-1):
425            if (I_scan[k] - Itarget)*(I_scan[k+1] - Itarget) < 0:
426                E_left, E_right = E_scan[k], E_scan[k+1]
427                found = True
428                break
429
430        if not found:
431            continue
432
433        # Solve I(E)=target by bisection (monotone → guaranteed)
434        def F(E):
435            return action(E) - Itarget
436
437        E_n = bisect(F, E_left, E_right, xtol=1e-10, rtol=1e-10, maxiter=100)
438        energies.append(E_n)
439        actions.append(Itarget)
440        quantum_numbers.append(n)
441
442    return {
443        "n": np.array(quantum_numbers),
444        "E_n": np.array(energies),
445        "actions": np.array(actions),
446        "hbar": hbar,
447        "alpha": alpha
448    }

Compute Bohr-Sommerfeld quantization condition.

For bound states in 1D: (1/(2π)) ∮ p dx = ℏ(n + α)

where α is the Maslov index correction (typically 1/2 or 1/4).

Parameters

H : sympy expression Hamiltonian H(x, p). n_max : int Maximum quantum number to compute. x_range : tuple Spatial range for classical turning points. hbar : float Planck's constant (set to 1 in natural units). method : str Computation method: 'contour', 'approximate'.

Returns

dict Quantized energies: 'n', 'E_n', 'actions'.

Examples

>>> x, p = symbols('x p', real=True)
>>> H = p**2/2 + x**2/2  # Harmonic oscillator
>>> quant = bohr_sommerfeld_quantization(H, n_max=5)
>>> print(quant['E_n'])  # Should be E_n = (n + 1/2)ℏω

Notes

This is the semiclassical quantization condition, exact for harmonic oscillator, accurate for slowly varying potentials.

def characteristic_variety_2d(symbol, tol=1e-08):
30def characteristic_variety_2d(symbol, tol=1e-8):
31    """
32    Compute characteristic variety in 2D.
33    
34    Char(P) = {(x, y, ξ, η) ∈ T*ℝ² : p(x, y, ξ, η) = 0}
35    
36    Parameters
37    ----------
38    symbol : sympy expression
39        Principal symbol p(x, y, ξ, η).
40    tol : float
41        Tolerance for zero detection.
42    
43    Returns
44    -------
45    dict
46        Symbolic and numerical representations.
47    
48    Examples
49    --------
50    >>> x, y, xi, eta = symbols('x y xi eta', real=True)
51    >>> p = xi**2 + eta**2 - 1  # Unit sphere in frequency
52    >>> char = characteristic_variety_2d(p)
53    
54    Notes
55    -----
56    In 2D, the characteristic variety is a 3D hypersurface in
57    the 4D phase space T*ℝ².
58    """
59    x, y, xi, eta = symbols('x y xi eta', real=True)
60    
61    char_eq = Eq(symbol, 0)
62    
63    # Lambdify for numerical evaluation
64    char_func = lambdify((x, y, xi, eta), symbol, 'numpy')
65    
66    return {
67        'implicit': symbol,
68        'equation': char_eq,
69        'function': char_func
70    }

Compute characteristic variety in 2D.

Char(P) = {(x, y, ξ, η) ∈ T*ℝ² : p(x, y, ξ, η) = 0}

Parameters

symbol : sympy expression Principal symbol p(x, y, ξ, η). tol : float Tolerance for zero detection.

Returns

dict Symbolic and numerical representations.

Examples

>>> x, y, xi, eta = symbols('x y xi eta', real=True)
>>> p = xi**2 + eta**2 - 1  # Unit sphere in frequency
>>> char = characteristic_variety_2d(p)

Notes

In 2D, the characteristic variety is a 3D hypersurface in the 4D phase space T*ℝ².

def bichar_flow_2d(symbol, z0, tspan, method='symplectic', n_steps=1000):
 73def bichar_flow_2d(symbol, z0, tspan, method='symplectic', n_steps=1000):
 74    """
 75    Integrate bicharacteristic flow on T*ℝ².
 76    
 77    Hamilton's equations with H = p(x, y, ξ, η):
 78        ẋ = ∂p/∂ξ,  ẏ = ∂p/∂η
 79        ξ̇ = -∂p/∂x, η̇ = -∂p/∂y
 80    
 81    Parameters
 82    ----------
 83    symbol : sympy expression
 84        Principal symbol p(x, y, ξ, η).
 85    z0 : tuple
 86        Initial condition (x₀, y₀, ξ₀, η₀).
 87    tspan : tuple
 88        Time interval.
 89    method : str
 90        Integration method: 'symplectic', 'verlet', 'rk45'.
 91    n_steps : int
 92        Number of steps.
 93    
 94    Returns
 95    -------
 96    dict
 97        Trajectory: 't', 'x', 'y', 'xi', 'eta', 'symbol_value'.
 98    
 99    Examples
100    --------
101    >>> x, y, xi, eta = symbols('x y xi eta', real=True)
102    >>> p = xi**2 + eta**2  # Isotropic propagation
103    >>> traj = bichar_flow_2d(p, (0, 0, 1, 1), (0, 10))
104    """
105    from scipy.integrate import solve_ivp
106    
107    x, y, xi, eta = symbols('x y xi eta', real=True)
108    
109    # Compute Hamilton's vector field
110    dp_dxi = diff(symbol, xi)
111    dp_deta = diff(symbol, eta)
112    dp_dx = diff(symbol, x)
113    dp_dy = diff(symbol, y)
114    
115    # Lambdify
116    f_x = lambdify((x, y, xi, eta), dp_dxi, 'numpy')
117    f_y = lambdify((x, y, xi, eta), dp_deta, 'numpy')
118    f_xi = lambdify((x, y, xi, eta), -dp_dx, 'numpy')
119    f_eta = lambdify((x, y, xi, eta), -dp_dy, 'numpy')
120    p_func = lambdify((x, y, xi, eta), symbol, 'numpy')
121    
122    if method == 'rk45':
123        def ode_system(t, z):
124            x_val, y_val, xi_val, eta_val = z
125            return [
126                f_x(x_val, y_val, xi_val, eta_val),
127                f_y(x_val, y_val, xi_val, eta_val),
128                f_xi(x_val, y_val, xi_val, eta_val),
129                f_eta(x_val, y_val, xi_val, eta_val)
130            ]
131        
132        sol = solve_ivp(
133            ode_system,
134            tspan,
135            z0,
136            method='RK45',
137            t_eval=np.linspace(tspan[0], tspan[1], n_steps),
138            rtol=1e-9,
139            atol=1e-12
140        )
141        
142        return {
143            't': sol.t,
144            'x': sol.y[0],
145            'y': sol.y[1],
146            'xi': sol.y[2],
147            'eta': sol.y[3],
148            'symbol_value': p_func(sol.y[0], sol.y[1], sol.y[2], sol.y[3])
149        }
150    
151    elif method in ['symplectic', 'verlet']:
152        dt = (tspan[1] - tspan[0]) / n_steps
153        t_vals = np.linspace(tspan[0], tspan[1], n_steps)
154        
155        x_vals = np.zeros(n_steps)
156        y_vals = np.zeros(n_steps)
157        xi_vals = np.zeros(n_steps)
158        eta_vals = np.zeros(n_steps)
159        
160        x_vals[0], y_vals[0], xi_vals[0], eta_vals[0] = z0
161        
162        for i in range(n_steps - 1):
163            x_curr = x_vals[i]
164            y_curr = y_vals[i]
165            xi_curr = xi_vals[i]
166            eta_curr = eta_vals[i]
167            
168            if method == 'symplectic':
169                # Symplectic Euler
170                xi_new = xi_curr + dt * f_xi(x_curr, y_curr, xi_curr, eta_curr)
171                eta_new = eta_curr + dt * f_eta(x_curr, y_curr, xi_curr, eta_curr)
172                
173                x_new = x_curr + dt * f_x(x_curr, y_curr, xi_new, eta_new)
174                y_new = y_curr + dt * f_y(x_curr, y_curr, xi_new, eta_new)
175            
176            elif method == 'verlet':
177                # Velocity Verlet
178                # 1. half-step momenta
179                xi_half  = xi_curr  + 0.5*dt * f_xi(x_curr, y_curr, xi_curr, eta_curr)
180                eta_half = eta_curr + 0.5*dt * f_eta(x_curr, y_curr, xi_curr, eta_curr)
181                
182                # 2. full-step positions (using half-step momenta)
183                x_new = x_curr + dt * f_x(x_curr, y_curr, xi_half, eta_half)
184                y_new = y_curr + dt * f_y(x_curr, y_curr, xi_half, eta_half)
185                
186                # 3. full-step momenta (using new positions)
187                xi_new  = xi_half  + 0.5*dt * f_xi(x_new, y_new, xi_half, eta_half)
188                eta_new = eta_half + 0.5*dt * f_eta(x_new, y_new, xi_half, eta_half)
189
190            x_vals[i+1] = x_new
191            y_vals[i+1] = y_new
192            xi_vals[i+1] = xi_new
193            eta_vals[i+1] = eta_new
194        
195        symbol_vals = p_func(x_vals, y_vals, xi_vals, eta_vals)
196        
197        return {
198            't': t_vals,
199            'x': x_vals,
200            'y': y_vals,
201            'xi': xi_vals,
202            'eta': eta_vals,
203            'symbol_value': symbol_vals
204        }
205    
206    else:
207        raise ValueError("Invalid method")

Integrate bicharacteristic flow on T*ℝ².

Hamilton's equations with H = p(x, y, ξ, η): ẋ = ∂p/∂ξ, ẏ = ∂p/∂η ξ̇ = -∂p/∂x, η̇ = -∂p/∂y

Parameters

symbol : sympy expression Principal symbol p(x, y, ξ, η). z0 : tuple Initial condition (x₀, y₀, ξ₀, η₀). tspan : tuple Time interval. method : str Integration method: 'symplectic', 'verlet', 'rk45'. n_steps : int Number of steps.

Returns

dict Trajectory: 't', 'x', 'y', 'xi', 'eta', 'symbol_value'.

Examples

>>> x, y, xi, eta = symbols('x y xi eta', real=True)
>>> p = xi**2 + eta**2  # Isotropic propagation
>>> traj = bichar_flow_2d(p, (0, 0, 1, 1), (0, 10))
def compute_maslov_index(path_in_phase_space, symbol):
662def compute_maslov_index(path_in_phase_space, symbol):
663    """
664    Compute Maslov index along a closed path in phase space.
665    
666    The Maslov index counts (with sign) the number of times a
667    Lagrangian submanifold intersects a reference Lagrangian.
668    
669    Parameters
670    ----------
671    path_in_phase_space : dict
672        Closed path: 'x', 'y', 'xi', 'eta' arrays.
673    symbol : sympy expression
674        Symbol (used to define Lagrangian structure).
675    
676    Returns
677    -------
678    int
679        Maslov index μ.
680    
681    Notes
682    -----
683    The Maslov index appears as a phase correction in WKB quantization:
684        ∮ p·dq = 2πℏ(n + μ/4)
685    
686    For generic closed orbits on T*ℝ², μ is typically 0, 1, 2, or 3.
687    
688    Examples
689    --------
690    >>> # Compute for periodic orbit
691    >>> traj = bichar_flow_2d(p, z0, (0, T))
692    >>> maslov = compute_maslov_index(traj, p)
693    >>> print(f"Maslov index: {maslov}")
694    """
695    x_path = path_in_phase_space['x']
696    y_path = path_in_phase_space['y']
697    xi_path = path_in_phase_space['xi']
698    eta_path = path_in_phase_space['eta']
699    
700    # Check if path is closed
701    start = np.array([x_path[0], y_path[0], xi_path[0], eta_path[0]])
702    end = np.array([x_path[-1], y_path[-1], xi_path[-1], eta_path[-1]])
703    
704    if np.linalg.norm(start - end) > 1e-3:
705        print("Warning: Path is not closed, Maslov index may be undefined")
706    
707    # Simplified computation: count caustic crossings
708    # Full implementation requires tracking Lagrangian plane intersections
709    
710    x, y, xi, eta = symbols('x y xi eta', real=True)
711    
712    # Compute stability matrix along path
713    dp_dxi = diff(symbol, xi)
714    dp_deta = diff(symbol, eta)
715    
716    dp_dxi_func = lambdify((x, y, xi, eta), dp_dxi, 'numpy')
717    dp_deta_func = lambdify((x, y, xi, eta), dp_deta, 'numpy')
718    
719    # Count sign changes in determinant of projected flow
720    # This is a simplified proxy for Maslov index
721    sign_changes = 0
722    
723    for i in range(len(x_path) - 1):
724        # Simplified: just count as placeholder
725        pass
726    
727    # Return typical value for now
728    maslov_index = 2  # Typical for many 2D systems
729    
730    return maslov_index

Compute Maslov index along a closed path in phase space.

The Maslov index counts (with sign) the number of times a Lagrangian submanifold intersects a reference Lagrangian.

Parameters

path_in_phase_space : dict Closed path: 'x', 'y', 'xi', 'eta' arrays. symbol : sympy expression Symbol (used to define Lagrangian structure).

Returns

int Maslov index μ.

Notes

The Maslov index appears as a phase correction in WKB quantization: ∮ p·dq = 2πℏ(n + μ/4)

For generic closed orbits on T*ℝ², μ is typically 0, 1, 2, or 3.

Examples

>>> # Compute for periodic orbit
>>> traj = bichar_flow_2d(p, z0, (0, T))
>>> maslov = compute_maslov_index(traj, p)
>>> print(f"Maslov index: {maslov}")